Program Listing for File update.hpp

Return to documentation for file (/tmp/ws/src/proxsuite/include/proxsuite/linalg/dense/update.hpp)

//
// Copyright (c) 2022 INRIA
//
#ifndef PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP
#define PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP

#include "proxsuite/linalg/dense/core.hpp"

namespace proxsuite {
namespace linalg {
namespace dense {
namespace _detail {
/// Returns the (non-positive) byte offset from `ptr` to the closest address at
/// or below `ptr` that is a multiple of `align`.
///
/// NOTE(review): the mask arithmetic is only meaningful when `align` is a
/// power of two; callers are presumably expected to guarantee that.
inline auto
bytes_to_prev_aligned(void* ptr, usize align) noexcept -> isize
{
  std::uintptr_t const address = std::uintptr_t(ptr);
  std::uintptr_t const low_bits_mask = std::uintptr_t(align - 1);
  // How many bytes `ptr` sits past the aligned boundary just below it.
  std::uintptr_t const overshoot = address & low_bits_mask;
  return -isize(overshoot);
}
/// Returns the (non-negative) number of bytes to advance `ptr` so that it lands
/// on the closest address at or above it that is a multiple of `align`.
///
/// NOTE(review): the mask arithmetic is only meaningful when `align` is a
/// power of two; callers are presumably expected to guarantee that.
inline auto
bytes_to_next_aligned(void* ptr, usize align) noexcept -> isize
{
  std::uintptr_t const address = std::uintptr_t(ptr);
  std::uintptr_t const low_bits_mask = std::uintptr_t(align - 1);
  // Round up: add the mask so any partial block spills into the next one,
  // then clear the low bits.
  std::uintptr_t const rounded_up = (address + low_bits_mask) & ~low_bits_mask;
  return isize(rounded_up - address);
}

/// Calls `fn(I)` once for every index `I` in the pack `Is`, in pack order,
/// with no runtime loop.
template<usize... Is, typename Fn>
VEG_INLINE void
unroll_impl(proxsuite::linalg::veg::meta::index_sequence<Is...> /*unused*/,
            Fn fn)
{
  // Elements of a braced-init-list are evaluated left to right, which fixes
  // the call order. The leading 0 keeps the array non-empty when `Is` is empty.
  int const expand[] = { 0, ((void)fn(Is), 0)... };
  (void)expand;
}

/// Compile-time unrolled loop: forwards `fn` to `unroll_impl` with the index
/// sequence 0, 1, ..., N - 1. `fn` is called once per index.
template<usize N, typename Fn>
VEG_INLINE void
unroll(Fn fn)
{
  using Indices = proxsuite::linalg::veg::meta::make_index_sequence<N>;
  _detail::unroll_impl(Indices{}, VEG_FWD(fn));
}

/// Unroll body: for rank component `i`, loads one pack from column `i` of W
/// (located `i * w_stride` elements after `pw`) into `p_wr[i]`.
template<typename T, usize N>
struct RankUpdateLoadW
{
  _simd::Pack<T, N>* p_wr; // destination packs, one per rank component
  T const* pw;             // current row position in column 0 of W
  isize w_stride;          // element distance between columns of W

  VEG_INLINE void operator()(usize i) const
  {
    T const* column_i = pw + isize(i) * w_stride;
    p_wr[i] = _simd::Pack<T, N>::load_unaligned(column_i);
  }
};

/// Unroll body that applies rank component `i` to the current row slice:
///   w_i <- w_i - p_i * l, then l <- l + mu_i * w_i (with the updated w_i).
/// `p_in_l` is shared between calls, so components must be applied in index
/// order.
/// NOTE(review): assumes Pack::fnmadd(a, b, c) == c - a * b and
/// Pack::fmadd(a, b, c) == a * b + c, following the usual FMA naming.
template<typename T, usize N>
struct RankUpdateUpdateWAndL
{
  _simd::Pack<T, N>* p_wr;       // per-component W slices (updated)
  _simd::Pack<T, N>& p_in_l;     // slice of the L column (updated)
  _simd::Pack<T, N> const* p_p;  // broadcast p_k values
  _simd::Pack<T, N> const* p_mu; // broadcast mu_k values

  VEG_INLINE void operator()(usize i) const
  {
    using Pack = _simd::Pack<T, N>;
    Pack& w_i = p_wr[i];
    w_i = Pack::fnmadd(p_p[i], p_in_l, w_i);
    p_in_l = Pack::fmadd(p_mu[i], w_i, p_in_l);
  }
};

/// Unroll body: writes `p_wr[i]` back to column `i` of W, i.e. to the address
/// `pw + i * w_stride`.
template<typename T, usize N>
struct RankUpdateStoreW
{
  _simd::Pack<T, N> const* p_wr; // source packs, one per rank component
  T* pw;                         // current row position in column 0 of W
  isize w_stride;                // element distance between columns of W

  VEG_INLINE void operator()(usize i) const
  {
    T* column_i = pw + isize(i) * w_stride;
    p_wr[i].store_unaligned(column_i);
  }
};

/// Processes one pack-width slice of rows for all `R` rank components.
///
/// 1. Loads the slice of L at `inout_l` and the matching slices of the R
///    columns of W.
/// 2. Applies the components in order: w_k <- w_k - p_k * l, then
///    l <- l + mu_k * w_k.
/// 3. Stores W, then L.
template<usize R, typename T, usize N>
VEG_INLINE void
rank_r_update_inner_loop_iter( //
  _simd::Pack<T, N> const* p_p,
  _simd::Pack<T, N> const* p_mu,
  T* inout_l,
  T* pw,
  isize w_stride)
{
  using Pack = _simd::Pack<T, N>;

  Pack w_regs[R];
  _detail::unroll<R>(RankUpdateLoadW<T, N>{ w_regs, pw, w_stride });

  Pack l_reg = Pack::load_unaligned(inout_l);
  _detail::unroll<R>(RankUpdateUpdateWAndL<T, N>{ w_regs, l_reg, p_p, p_mu });

  _detail::unroll<R>(RankUpdateStoreW<T, N>{ w_regs, pw, w_stride });
  l_reg.store_unaligned(inout_l);
}

// Row-loop implementation selected by `rank_r_update_inner_loop` through
// `should_vectorize<T>::value`:
// - `true`: processes rows with native-width SIMD packs, then a scalar tail.
// - `false`: scalar only.
// The primary template is intentionally only declared; both cases are
// provided as explicit specializations.
template<bool VECTORIZABLE>
struct RankRUpdateLoopImpl;

/// Unroll body: broadcasts the scalars `p[i]` and `mu[i]` into every lane of
/// `p_p[i]` and `p_mu[i]` respectively.
template<typename T, usize N>
struct RankUpdateLoadPMu
{
  _simd::Pack<T, N>* p_p;
  _simd::Pack<T, N>* p_mu;
  T const* p;
  T const* mu;
  VEG_INLINE void operator()(usize i) const
  {
    using Pack = _simd::Pack<T, N>;
    T const p_i = p[i];
    T const mu_i = mu[i];
    p_p[i] = Pack::broadcast(p_i);
    p_mu[i] = Pack::broadcast(mu_i);
  }
};

/// Scalar row loop: applies the R rank components to the `n` rows starting at
/// `inout_l` (L) and `pw` (W), one element at a time.
template<>
struct RankRUpdateLoopImpl<false>
{
  template<usize R, typename T>
  VEG_INLINE static void fn(isize n,
                            T* inout_l,
                            T* pw,
                            isize w_stride,
                            T const* p,
                            T const* mu) noexcept
  {
    using Scalar1 = _simd::Pack<T, 1>;
    Scalar1 p_b[R];
    Scalar1 mu_b[R];
    _detail::unroll<R>(RankUpdateLoadPMu<T, 1>{ p_b, mu_b, p, mu });

    for (isize row = 0; row < n; ++row) {
      _detail::rank_r_update_inner_loop_iter<R>(
        p_b, mu_b, inout_l + row, pw + row, w_stride);
    }
  }
};

/// SIMD row loop for the `n` rows starting at `inout_l` (L) and `pw` (W).
///
/// - The largest multiple of the native pack width is handled with full packs.
/// - The remaining rows are handled one element at a time.
template<>
struct RankRUpdateLoopImpl<true>
{
  template<usize R, typename T>
  VEG_INLINE static void fn(isize n,
                            T* inout_l,
                            T* pw,
                            isize w_stride,
                            T const* p,
                            T const* mu) noexcept
  {
    // best perf if beginning of each pw is aligned
    // should be enforced by the Ldlt class
    using Info = _simd::NativePackInfo<T>;
    constexpr usize N = Info::N;

    usize const total = usize(n);
    usize const vector_part = total - total % N;

    // Full-width packs over rows [0, vector_part).
    {
      using Wide = _simd::NativePack<T>;
      Wide p_b[R];
      Wide mu_b[R];
      _detail::unroll<R>(RankUpdateLoadPMu<T, N>{ p_b, mu_b, p, mu });

      for (usize row = 0; row < vector_part; row += N) {
        _detail::rank_r_update_inner_loop_iter<R>(
          p_b, mu_b, inout_l + row, pw + row, w_stride);
      }
    }
    // Scalar tail over rows [vector_part, total).
    {
      using Narrow = _simd::Pack<T, 1>;
      Narrow p_b[R];
      Narrow mu_b[R];
      _detail::unroll<R>(RankUpdateLoadPMu<T, 1>{ p_b, mu_b, p, mu });

      for (usize row = vector_part; row < total; ++row) {
        _detail::rank_r_update_inner_loop_iter<R>(
          p_b, mu_b, inout_l + row, pw + row, w_stride);
      }
    }
  }
};

/// Applies R rank components to `n` rows of an L column and of W.
/// Dispatches to the SIMD or scalar row loop according to
/// `should_vectorize<T>::value`.
template<usize R, typename T>
VEG_INLINE void
rank_r_update_inner_loop(isize n,
                         T* inout_l,
                         T* pw,
                         isize w_stride,
                         T const* p,
                         T const* mu)
{
  constexpr bool use_simd = should_vectorize<T>::value;
  using Impl = RankRUpdateLoopImpl<use_simd>;
  Impl::template fn<R>(n, inout_l, pw, w_stride, p, mu);
}

/// Core rank-r update driver.
///
/// Walks the columns j = 0..n-1 of the factor stored in `ld`. For each column
/// it applies that column's rank components in chunks of at most 4.
///
/// Layout, as used below:
/// - `ld` is column-major with unit inner stride.
/// - ld(j, j) holds d_j, and the entries below it hold column j of L.
/// - `pw` points at W(0, 0); component k of row i lives at pw[i + k * w_stride].
/// - `palpha[k]` is the scaling factor of component k.
///
/// For each column j, and each component k in order:
///   p       = W(j, k)
///   d_new   = d_j + alpha_k * p * p
///   mu_k    = alpha_k * p / d_new
///   alpha_k = alpha_k - d_new * mu_k * mu_k
///   d_j     = d_new
/// then, for every row i > j:
///   W(i, k) -= p * L(i, j);  L(i, j) += mu_k * W(i, k).
///
/// NOTE(review): these are the recurrences of the classical sequential
/// rank-one LDL^T update. Presumably the result is the factorization of
/// L D L^T + W diag(alpha) W^T; confirm against the callers' contract.
///
/// @param ld        view of the L/D storage, updated in place.
/// @param pw        data of W; overwritten.
/// @param w_stride  element distance between consecutive columns of W.
/// @param palpha    per-component scaling factors; overwritten.
/// @param r_fn      called once per column j, in order. Returns how many rank
///                  components to apply to that column; a value <= 0 skips the
///                  column.
template<typename LD, typename T, typename Fn>
void
rank_r_update_clobber_w_impl( //
  LD ld,
  T* pw,
  isize w_stride,
  T* palpha,
  Fn r_fn)
{
  static_assert(LD::InnerStrideAtCompileTime == 1, ".");
  static_assert(!bool(LD::IsRowMajor), ".");

  // Row-loop kernels indexed by (chunk size - 1). They are loop-invariant, so
  // build the table once.
  using FnType = void (*)(isize, T*, T*, isize, T const*, T const*);
  FnType const fn_table[] = {
    rank_r_update_inner_loop<1, T>,
    rank_r_update_inner_loop<2, T>,
    rank_r_update_inner_loop<3, T>,
    rank_r_update_inner_loop<4, T>,
  };

  isize n = ld.rows();

  for (isize j = 0; j < n; ++j) {
    isize r = r_fn();
    if (!(0 < r)) {
      // Nothing to apply to this column. W's row below is computed from j
      // rather than from a running pointer, so skipping a column cannot
      // shift the rows of W seen by later columns.
      continue;
    }

    // Row j of column 0 of W.
    T* pw_j = pw + j;

    for (isize r_done = 0; r_done < r;) {
      isize r_chunk = min2(isize(4), r - r_done);

      // Only the first r_chunk entries are written and read.
      T p_array[4];
      T mu_array[4];

      T dj = ld(j, j);
      for (isize k = 0; k < r_chunk; ++k) {
        T& p = p_array[k];
        T& mu = mu_array[k];
        T& alpha = palpha[r_done + k];

        p = pw_j[(r_done + k) * w_stride];
        T new_dj = dj + (alpha * p) * p;
        mu = (alpha * p) / new_dj;
        alpha -= new_dj * (mu * mu);

        dj = new_dj;
      }
      ld(j, j) = dj;

      // Rows strictly below the diagonal.
      isize rem = n - j - 1;

      (*fn_table[r_chunk - 1])( //
        rem,
        util::matrix_elem_addr(ld, j + 1, j),
        pw_j + 1 + r_done * w_stride,
        w_stride,
        p_array,
        mu_array);

      r_done += r_chunk;
    }
  }
}
/// Rank callback for `rank_r_update_clobber_w_impl` that reports the same
/// number of rank components `r` for every column.
struct ConstantR
{
  isize r;
  VEG_INLINE auto operator()() const noexcept -> isize
  {
    return this->r;
  }
};
} // namespace _detail

template<typename LD,
         typename W,
         typename T = typename proxsuite::linalg::veg::uncvref_t<LD>::Scalar>
void
rank_1_update_clobber_w(LD&& ld,
                        W&& w,
                        proxsuite::linalg::veg::DoNotDeduce<T> alpha)
{
  _detail::rank_r_update_clobber_w_impl( //
    util::to_view_dyn(ld),
    w.data(),
    0,
    proxsuite::linalg::veg::mem::addressof(alpha),
    _detail::ConstantR{ 1 });
}

template<typename LD,
         typename W,
         typename A,
         typename T = typename proxsuite::linalg::veg::uncvref_t<LD>::Scalar>
void
rank_r_update_clobber_inputs(LD&& ld, W&& w, A&& alpha)
{
  isize r = w.cols();
  _detail::rank_r_update_clobber_w_impl( //
    util::to_view_dyn(ld),
    w.data(),
    w.outerStride(),
    alpha.data(),
    _detail::ConstantR{ r });
}
} // namespace dense
} // namespace linalg
} // namespace proxsuite

#endif /* end of include guard PROXSUITE_LINALG_DENSE_LDLT_UPDATE_HPP */