.. _program_listing_file__tmp_ws_src_proxsuite_include_proxsuite_linalg_dense_update.hpp: Program Listing for File update.hpp =================================== |exhale_lsh| :ref:`Return to documentation for file ` (``/tmp/ws/src/proxsuite/include/proxsuite/linalg/dense/update.hpp``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp // // Copyright (c) 2022 INRIA // #ifndef PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP #define PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP #include "proxsuite/linalg/dense/core.hpp" namespace proxsuite { namespace linalg { namespace dense { namespace _detail { inline auto bytes_to_prev_aligned(void* ptr, usize align) noexcept -> isize { using UPtr = std::uintptr_t; UPtr mask = align - 1; UPtr iptr = UPtr(ptr); UPtr aligned_ptr = iptr & ~mask; return isize(aligned_ptr - iptr); } inline auto bytes_to_next_aligned(void* ptr, usize align) noexcept -> isize { using UPtr = std::uintptr_t; UPtr mask = align - 1; UPtr iptr = UPtr(ptr); UPtr aligned_ptr = (iptr + mask) & ~mask; return isize(aligned_ptr - iptr); } template VEG_INLINE void unroll_impl(proxsuite::linalg::veg::meta::index_sequence /*unused*/, Fn fn) { VEG_EVAL_ALL(fn(Is)); } template VEG_INLINE void unroll(Fn fn) { _detail::unroll_impl(proxsuite::linalg::veg::meta::make_index_sequence{}, VEG_FWD(fn)); } template struct RankUpdateLoadW { _simd::Pack* p_wr; T const* pw; isize w_stride; VEG_INLINE void operator()(usize i) const { p_wr[i] = _simd::Pack::load_unaligned(pw + w_stride * isize(i)); } }; template struct RankUpdateUpdateWAndL { _simd::Pack* p_wr; _simd::Pack& p_in_l; _simd::Pack const* p_p; _simd::Pack const* p_mu; VEG_INLINE void operator()(usize i) const { p_wr[i] = _simd::Pack::fnmadd(p_p[i], p_in_l, p_wr[i]); p_in_l = _simd::Pack::fmadd(p_mu[i], p_wr[i], p_in_l); } }; template struct RankUpdateStoreW { _simd::Pack const* p_wr; T* pw; isize w_stride; VEG_INLINE void operator()(usize i) const { p_wr[i].store_unaligned(pw + w_stride * isize(i)); } }; template VEG_INLINE void rank_r_update_inner_loop_iter( // _simd::Pack const* p_p, _simd::Pack const* p_mu, T* inout_l, T* pw, isize w_stride) { _simd::Pack p_wr[R]; _detail::unroll(RankUpdateLoadW{ p_wr, pw, w_stride }); _simd::Pack p_in_l = _simd::Pack::load_unaligned(inout_l); _detail::unroll(RankUpdateUpdateWAndL{ p_wr, p_in_l, p_p, p_mu }); _detail::unroll(RankUpdateStoreW{ p_wr, pw, w_stride }); p_in_l.store_unaligned(inout_l); } template struct RankRUpdateLoopImpl; template struct RankUpdateLoadPMu { _simd::Pack* p_p; _simd::Pack* p_mu; T const* p; T const* mu; VEG_INLINE void operator()(usize i) const { p_p[i] = _simd::Pack::broadcast(p[i]); p_mu[i] = _simd::Pack::broadcast(mu[i]); } }; template<> struct RankRUpdateLoopImpl { template VEG_INLINE static void fn(isize n, T* inout_l, T* pw, isize w_stride, T const* p, T const* mu) noexcept { using Pack_ = _simd::Pack; Pack_ p_p[R]; Pack_ p_mu[R]; _detail::unroll(RankUpdateLoadPMu{ p_p, p_mu, p, mu }); auto inout_l_finish = inout_l + n; while (inout_l < inout_l_finish) { _detail::rank_r_update_inner_loop_iter( p_p, p_mu, inout_l, pw, w_stride); ++inout_l; ++pw; } } }; template<> struct RankRUpdateLoopImpl { template VEG_INLINE static void fn(isize n, T* inout_l, T* pw, isize w_stride, T const* p, T const* mu) noexcept { // best perf if beginning of each pw is aligned // should be enforced by the Ldlt class using Info = _simd::NativePackInfo; constexpr usize N = Info::N; auto inout_l_vectorized_end = inout_l + usize(n) / N * N; auto inout_l_end = inout_l + usize(n); { using Pack = _simd::NativePack; Pack p_p[R]; Pack p_mu[R]; _detail::unroll(RankUpdateLoadPMu{ p_p, p_mu, p, mu }); while (inout_l < inout_l_vectorized_end) { _detail::rank_r_update_inner_loop_iter( p_p, p_mu, inout_l, pw, w_stride); inout_l += N; pw += N; } } { using Pack_ = _simd::Pack; Pack_ p_p[R]; Pack_ p_mu[R]; _detail::unroll(RankUpdateLoadPMu{ p_p, p_mu, p, mu }); while (inout_l < inout_l_end) { _detail::rank_r_update_inner_loop_iter( p_p, p_mu, inout_l, pw, w_stride); ++inout_l; ++pw; } } } }; template VEG_INLINE void rank_r_update_inner_loop(isize n, T* inout_l, T* pw, isize w_stride, T const* p, T const* mu) { RankRUpdateLoopImpl::value>::template fn( n, inout_l, pw, w_stride, p, mu); } template void rank_r_update_clobber_w_impl( // LD ld, T* pw, isize w_stride, T* palpha, Fn r_fn) { static_assert(LD::InnerStrideAtCompileTime == 1, "."); static_assert(!bool(LD::IsRowMajor), "."); isize n = ld.rows(); for (isize j = 0; j < n; ++j) { isize r = r_fn(); isize r_done = 0; if (!(r_done < r)) { continue; } while (true) { isize r_chunk = min2(isize(4), r - r_done); T p_array[4]; T mu_array[4]; T dj = ld(j, j); for (isize k = 0; k < r_chunk; ++k) { auto& p = (+p_array)[k]; auto& mu = (+mu_array)[k]; auto& alpha = palpha[r_done + k]; p = pw[(r_done + k) * w_stride]; T new_dj = dj + (alpha * p) * p; mu = (alpha * p) / new_dj; alpha -= new_dj * (mu * mu); dj = new_dj; } ld(j, j) = dj; isize rem = n - j - 1; using FnType = void (*)(isize, T*, T*, isize, T const*, T const*); FnType fn_table[] = { rank_r_update_inner_loop<1, T>, rank_r_update_inner_loop<2, T>, rank_r_update_inner_loop<3, T>, rank_r_update_inner_loop<4, T>, }; (*fn_table[r_chunk - 1])( // rem, util::matrix_elem_addr(ld, j + 1, j), pw + 1 + r_done * w_stride, w_stride, p_array, mu_array); r_done += r_chunk; if (!(r_done < r)) { break; } } ++pw; } } struct ConstantR { isize r; VEG_INLINE auto operator()() const noexcept -> isize { return r; } }; } // namespace _detail template::Scalar> void rank_1_update_clobber_w(LD&& ld, W&& w, proxsuite::linalg::veg::DoNotDeduce alpha) { _detail::rank_r_update_clobber_w_impl( // util::to_view_dyn(ld), w.data(), 0, proxsuite::linalg::veg::mem::addressof(alpha), _detail::ConstantR{ 1 }); } template::Scalar> void rank_r_update_clobber_inputs(LD&& ld, W&& w, A&& alpha) { isize r = w.cols(); _detail::rank_r_update_clobber_w_impl( // util::to_view_dyn(ld), w.data(), w.outerStride(), alpha.data(), _detail::ConstantR{ r }); } } // namespace dense } // namespace linalg } // namespace proxsuite #endif /* end of include guard PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP */