// Copyright (c) 2022 INRIA

#include "proxsuite/linalg/sparse/update.hpp"
#include <algorithm>

namespace proxsuite {
namespace linalg {
namespace sparse {

// Returns the dynamic-stack workspace requirement for delete_row().
// As shown, the requirement is whatever sparse::rank1_update_req reports.
// `n` and `max_nnz` are presumably the matrix dimension and a bound on the
// number of stored non-zeros -- confirm against rank1_update_req.
// NOTE(review): this chunk is missing lines. A trailing return type needs a
// leading `auto`. The opening brace, the remaining rank1_update_req
// arguments and the closing brace are also not visible. Restore them from
// the full file before editing this block.
template<typename T, typename I>
delete_row_req( //
  proxsuite::linalg::veg::Tag<T> /*tag*/,
  proxsuite::linalg::veg::Tag<I> /*tag*/,
  isize n,
  isize max_nnz) noexcept -> proxsuite::linalg::veg::dynstack::StackReq
  // NOTE(review): the remaining arguments of this call are not visible here.
  return sparse::rank1_update_req(proxsuite::linalg::veg::Tag<T>{},

// Removes row/column `pos` from the factorization stored in `ld`. When
// perm_inv is non-null, the permuted index perm_inv[pos] is removed instead.
// The work is done in the four labelled steps below:
//   1. delete that row from every column j < permuted_pos, updating etree;
//   2. set the diagonal entry of column permuted_pos to 1 (the previous
//      value is saved in d_old);
//   3. call sparse::rank1_update with the sub-diagonal part of that column;
//   4. reset that column to its diagonal entry only and detach it from etree.
// Returns the updated matrix view (also reassigned to `ld` in step 3).
// NOTE(review): this chunk is missing lines. Not visible here:
//   - the leading `auto` required by the trailing return type;
//   - the opening/closing braces of the function, the loop and the if/else
//     statements;
//   - part of the arguments of the rank1_update call.
// Restore them from the full file before editing.
template<typename T, typename I>
delete_row(MatMut<T, I> ld,
           I* etree,
           I const* perm_inv,
           isize pos,
           DynStackMut stack) noexcept(false) -> MatMut<T, I>
  // step 1: delete row k from each column

  // we're actually deleting perm_inv[k], so that k is deleted in the permuted
  // matrix
  usize permuted_pos =
    perm_inv == nullptr ? usize(pos) : util::zero_extend(perm_inv[pos]);

  auto petree = etree;
  I* pldi = ld.row_indices_mut();
  T* pldx = ld.values_mut();
  I* pldnz = ld.nnz_per_col_mut();

  // only columns left of permuted_pos can store row permuted_pos
  for (usize j = 0; j < permuted_pos; ++j) {
    // the `+ 1` skips the first stored entry of the column (the diagonal,
    // cf. step 2, which reads d_kk at col_start)
    auto col_start = ld.col_start(j) + 1;
    auto col_end = ld.col_end(j);
    // search for the first row in column j greater than or equal to k
    // (row indices within a column are searched with lower_bound, i.e. they
    // are expected to be sorted)
    auto it =
      std::lower_bound(pldi + col_start, pldi + col_end, I(permuted_pos));

    // if an element was found, and it is equal to k
    if ((it != (pldi + col_end)) && *it == I(permuted_pos)) {
      usize it_pos = usize(it - (pldi + col_start));
      // number of stored entries in [it, col_end), including the removed one
      usize count = (col_end - col_start - it_pos);
      // shift all the row indices back by one position
      // to delete row k
      // NOTE(review): copying `count` entries starting at it + 1 reads the
      // slot at col_end, one past the column's last stored entry. The writes
      // stay within [it, col_end). Confirm this is intended, or whether
      // `count - 1` was meant (same for the values below).
      std::memmove(it, it + 1, count * sizeof(I));
      T* itx = pldx + col_start + it_pos;

      // shift all the values back by one position
      std::memmove(itx, itx + 1, count * sizeof(T));

      // decrement the non zero count
      // NOTE(review): only the global count is updated here. No decrement of
      // pldnz[j] is visible in this chunk, and the `pldnz[j] > 1` test below
      // depends on whether it has already happened -- confirm.
      ld._set_nnz(ld.nnz() - 1);

      // adjust the parent of j in the elimination tree if necessary
      if (petree[j] == I(permuted_pos)) {
        // the removed row was j's parent, so it must have been the first
        // sub-diagonal entry of column j
        VEG_ASSERT(it_pos == 0);
        if (pldnz[j] > 1) {
          // after the shift, *it is the next sub-diagonal row: the new parent
          petree[j] = *it;
        } else {
          // no sub-diagonal entry remains: column j becomes a root
          petree[j] = I(-1);

  // step 2: set d_kk = 1
  // d_old keeps the previous diagonal value; its use is not visible in this
  // chunk (presumably among the missing rank1_update arguments)
  T d_old = ld.values()[ld.col_start(permuted_pos)];
  ld.values_mut()[ld.col_start(permuted_pos)] = 1;

  // step 3: perform rank update
  // len = number of sub-diagonal entries stored in column permuted_pos
  isize len = isize(util::zero_extend(ld.nnz_per_col()[permuted_pos])) - 1;
  // The vector passed to rank1_update views the sub-diagonal part of column
  // permuted_pos in place (row indices and values after the diagonal).
  // NOTE(review): the leading argument (presumably `ld`, as in add_row) and
  // the trailing arguments of this call are not visible here.
  ld = sparse::rank1_update<T, I>( //
    static_cast<I const*>(nullptr),
    VecRef<T, I>{
      pldi + ld.col_start(permuted_pos) + 1,
      pldx + ld.col_start(permuted_pos) + 1,
  // step 4: delete col k
  // keep only the diagonal entry of the column and detach it from the etree
  ld.nnz_per_col_mut()[permuted_pos] = 1;
  petree[permuted_pos] = I(-1);
  return ld;
// Returns the dynamic-stack workspace requirement for add_row().
// Terms, as computed below:
//   numerical_work   - n values of T (presumably add_row's dense row/column
//                      workspace of ld.nrows() values);
//   permuted_indices - nnz indices of I, only when a permutation is used
//                      (id_perm == false);
//   pattern_diff     - n indices of I (presumably covering add_row's row
//                      pattern buffer of permuted_pos entries plus its
//                      difference buffer of nrows - permuted_pos entries);
//   merge            - merge_second_col_into_first_req for n indices;
//   update           - sparse::rank1_update_req.
// The first four are combined with `&`, and the result with `update` via
// `|`. Presumably `&` accounts for buffers that are alive at the same time
// and `|` for alternatives that reuse the same space -- confirm against
// veg::dynstack::StackReq.
// NOTE(review): add_row also allocates a `visited` buffer of bool, and no
// bool term is visible here. Confirm it is covered (e.g. by `merge` or by a
// line missing from this chunk).
// NOTE(review): this chunk is missing lines. Not visible here:
//   - the leading `auto` for the trailing return type;
//   - the braces;
//   - the remaining rank1_update_req arguments.
template<typename T, typename I>
add_row_req( //
  proxsuite::linalg::veg::Tag<T> /*tag*/,
  proxsuite::linalg::veg::Tag<I> /*tag*/,
  isize n,
  bool id_perm,
  isize nnz,
  isize max_nnz) noexcept -> proxsuite::linalg::veg::dynstack::StackReq
  using proxsuite::linalg::veg::dynstack::StackReq;
  auto numerical_work = StackReq{ n * isize{ sizeof(T) }, isize{ alignof(T) } };
  // no storage is needed for permuted indices with the identity permutation
  auto permuted_indices =
    StackReq{ (id_perm ? 0 : nnz) * isize{ sizeof(I) }, isize{ alignof(I) } };
  auto pattern_diff = StackReq{ n * isize{ sizeof(I) }, isize{ alignof(I) } };
  auto merge =
    merge_second_col_into_first_req(proxsuite::linalg::veg::Tag<I>{}, n);
  // NOTE(review): the remaining arguments of this call are not visible here.
  auto update = sparse::rank1_update_req(proxsuite::linalg::veg::Tag<T>{},

  auto req = numerical_work;
  req = req & permuted_indices;
  req = req & pattern_diff;
  req = req & merge;
  req = req | update;

  return req;
// Inserts row/column `pos` into the factorization stored in `ld`. When
// perm_inv is non-null, the row/column is inserted at the permuted position
// perm_inv[pos]. The inputs are the sparse column `new_col` and its diagonal
// value `diag_element`.
// Outline of the visible code:
//   - compute the non-zero pattern of the new row (columns j < permuted_pos)
//     by walking `etree` upward from each entry of the new column;
//   - scatter new_col into a dense workspace, and append its rows below
//     permuted_pos to column permuted_pos;
//   - run the triangular solve over the row pattern, merging each visited
//     column's pattern into column permuted_pos;
//   - insert the row entries (l12_elem / d) into each column of the pattern,
//     updating diag_element and `etree`;
//   - write column permuted_pos (its diagonal and scaled sub-diagonal
//     entries);
//   - call sparse::rank1_update with that column.
// Precondition (asserted): column permuted_pos currently stores only one
// entry, its diagonal.
// NOTE(review): this chunk is missing lines. Not visible here:
//   - the leading `auto` for the trailing return type;
//   - several opening/closing braces and statement bodies;
//   - parts of the std::sort, merge, std::memmove and rank1_update calls.
// Restore them from the full file before editing this block.
template<typename T, typename I>
add_row(MatMut<T, I> ld,
        I* etree,
        I const* perm_inv,
        isize pos,
        VecRef<T, I> new_col,
        proxsuite::linalg::veg::DoNotDeduce<T> diag_element,
        DynStackMut stack) noexcept(false) -> MatMut<T, I>
  bool id_perm = perm_inv == nullptr;
  // shorthand for util::zero_extend, used for every index conversion below
  auto zx = util::zero_extend;

  I* pldp = ld.col_ptrs_mut();
  I* pldnz = ld.nnz_per_col_mut();
  I* pldi = ld.row_indices_mut();
  T* pldx = ld.values_mut();

  // actually inserting in the position perm_inv[k] so that row k is added in
  // the permuted matrix
  usize permuted_pos = id_perm ? usize(pos) : zx(perm_inv[pos]);
  VEG_ASSERT(pldnz[permuted_pos] == 1);

    // allocate workspace for numerical step, storage for the k-th row and k-th
    // column of the new matrix
    auto _lx2_storage = stack.make_new_for_overwrite(
      proxsuite::linalg::veg::Tag<T>{}, ld.nrows());
    auto plx2_storage = _lx2_storage.ptr_mut();

    // allocate workspace for permuted row indices of the new column if
    // necessary
    auto _new_col_permuted_indices = stack.make_new_for_overwrite(
      proxsuite::linalg::veg::Tag<I>{}, id_perm ? isize(0) : new_col.nnz());

    auto new_col_permuted_indices =
      id_perm ? new_col.row_indices() : _new_col_permuted_indices.ptr();

    // copy and sort permuted row indices
    if (!id_perm) {
      I* pnew_col_permuted_indices = _new_col_permuted_indices.ptr_mut();
      for (usize k = 0; k < usize(new_col.nnz()); ++k) {
        usize i = zx(new_col.row_indices()[k]);
        pnew_col_permuted_indices[k] = perm_inv[i];
      // NOTE(review): the start of the sort call over the permuted indices is
      // not visible in this chunk.
                pnew_col_permuted_indices + new_col.nnz());

    // allocate workspace for non-zero pattern of k-th row
    // (at most permuted_pos columns lie left of the new row/column)
    auto _l12_nnz_pattern = stack.make_new_for_overwrite(
      proxsuite::linalg::veg::Tag<I>{}, isize(permuted_pos));
    // buffer for the row indices that each merge adds to column permuted_pos
    // (at most nrows - permuted_pos of them)
    auto _difference = stack.make_new_for_overwrite(
      proxsuite::linalg::veg::Tag<I>{}, ld.nrows() - isize(permuted_pos));
    auto pdifference = _difference.ptr_mut();

    auto pl12_nnz_pattern = _l12_nnz_pattern.ptr_mut();
    usize l12_nnz_pattern_count = 0;

    // the non-zero pattern is the set of columns reachable from the non-zero
    // pattern of the added column through graph of L_{1..k,1..k}
    // instead of graph traversal, we can use the k-th elimination subtree as we
    // did in the initial factorization step

    // for each row in the added column
      // NOTE(review): the size argument of this allocation, and the line that
      // opens this scope, are not visible in this chunk.
      auto _visited = stack.make_new(proxsuite::linalg::veg::Tag<bool>{},
      bool* visited = _visited.ptr_mut();
      for (usize p = 0; p < usize(new_col.nnz()); ++p) {
        auto j = zx(new_col_permuted_indices[p]);
        // rows at or after permuted_pos are not columns of the new row's
        // pattern (the body of this branch is not visible in this chunk)
        if (j >= permuted_pos) {

        // add the ancestors of the corresponding column
        // ancestors are not sorted, but they are added in topological order,
        // which suffices for the triangular solve
        while (true) {
          if (visited[j]) {
          visited[j] = true;
          // NOTE(review): no increment of l12_nnz_pattern_count is visible
          // after this store; presumably it is on a missing line -- confirm.
          pl12_nnz_pattern[l12_nnz_pattern_count] = I(j);

          j = util::sign_extend(etree[j]);
          // stop at an etree root (-1), at a node outside the leading
          // permuted_pos x permuted_pos block, or at an already visited node
          if (j == usize(-1) || j >= permuted_pos || visited[j]) {
    // order the row pattern by increasing column index
    std::sort(pl12_nnz_pattern, pl12_nnz_pattern + l12_nnz_pattern_count);

    // zero the elements in the non-zero pattern of the solution (new k-th row)
    for (usize p = 0; p < l12_nnz_pattern_count; ++p) {
      plx2_storage[zx(pl12_nnz_pattern[p])] = 0;

    // insert the rhs of the k-th row triangular system in the top part of the
    // storage, and the bottom part of the added column in the bottom part of
    // the storage
    for (usize p = 0; p < usize(new_col.nnz()); ++p) {
      auto j = zx(new_col.row_indices()[p]);
      auto permuted_j = id_perm ? j : zx(perm_inv[j]);
      plx2_storage[permuted_j] = new_col.values()[p];

      // add the row indices of the bottom part of the added column, to the
      // k-th column of L
      if (permuted_j > permuted_pos) {
        usize nz = zx(pldnz[permuted_pos]);
        // the column's reserved capacity is pldp[k + 1] - pldp[k]
        VEG_ASSERT(nz < (zx(pldp[permuted_pos + 1]) - zx(pldp[permuted_pos])));
        // NOTE(review): no increment of pldnz[permuted_pos] is visible after
        // appending; presumably it is on a missing line -- confirm.
        pldi[zx(pldp[permuted_pos]) + nz] = I(permuted_j);
        ld._set_nnz(ld.nnz() + 1);
    // sort the added row indices (skipping the diagonal entry stored first)
    std::sort(pldi + zx(pldp[permuted_pos]) + 1,
              pldi + zx(pldp[permuted_pos]) + zx(pldnz[permuted_pos]));

    // TODO: fuse loops?

    for (usize p = 0; p < l12_nnz_pattern_count; ++p) {
      usize j = zx(pl12_nnz_pattern[p]);
      auto col_start = ld.col_start(j);
      auto col_end = ld.col_end(j);

      // update the pattern of the k-th column of L, with that of the bottom
      // part of the j-th column of L, ignoring the elements less than or equal
      // to k
      // NOTE(review): the beginning of this call (presumably a binding of the
      // result of merge_second_col_into_first) and its remaining arguments
      // are not visible in this chunk.
               (_, new_current_col, computed_difference),
                 pldi + (zx(pldp[permuted_pos]) + 1),
                 isize(zx(pldp[permuted_pos + 1]) - zx(pldp[permuted_pos])) - 1,
                 pldnz[permuted_pos] - 1,
                   pldi + (zx(pldp[j]) + 1),
                   isize(zx(pldnz[j])) - 1,

      // update column and global non-zero count
      pldnz[permuted_pos] += I(computed_difference.len());
      ld._set_nnz(ld.nnz() + computed_difference.len());

      // rows newly added to the pattern start from zero in the dense workspace
      for (usize q = 0; q < usize(computed_difference.len()); ++q) {
        plx2_storage[zx(computed_difference.ptr()[q])] = 0;

      // perform triangular solve and matrix vector product simultaneously
      auto const xj = plx2_storage[j];
      for (usize q = col_start + 1; q < col_end; ++q) {
        auto i = zx(pldi[q]);
        plx2_storage[i] -= pldx[q] * xj;

    // insert the k-th row into L
    for (usize p = 0; p < l12_nnz_pattern_count; ++p) {
      // for each column in the non-zero pattern of the k-th row

      usize j = zx(pl12_nnz_pattern[p]);
      auto col_start = ld.col_start(j);
      auto col_end = ld.col_end(j);
      // d is the diagonal entry of column j (stored first in the column)
      T d = pldx[col_start];
      T l12_elem = plx2_storage[j];
      // subtract column j's contribution l12_elem^2 / d from the new
      // diagonal entry
      diag_element -= l12_elem * l12_elem / d;

      // check that we have enough space to insert one element
      VEG_ASSERT(zx(pldnz[j]) < (zx(pldp[j + 1]) - zx(pldp[j])));

      // find the first element greater than or equal to k: this is the
      // insertion point for row k (presumably k is not yet stored in
      // column j, so this is the first element greater than k)
      auto it =
        std::lower_bound(pldi + col_start, pldi + col_end, I(permuted_pos));

      // if it is the first element, update the elimination tree so that k is
      // the new parent of column j
      if (it == (pldi + col_start + 1)) {
        etree[j] = I(permuted_pos);

      // shift the row indices up by one position to provide enough space for
      // the new element
      // NOTE(review): only two arguments of this memmove are visible; the
      // source argument (presumably `it`) is missing from this chunk.
      std::memmove( //
        it + 1,
        usize((pldi + col_end) - it) * sizeof(I));


      // shift the values up by one position to provide enough space for the
      // new element
      std::memmove( //
        pldx + (it - pldi) + 1,
        pldx + (it - pldi),
        usize((pldi + col_end) - it) * sizeof(T));

      // insert the new row index k
      *it = I(permuted_pos);
      // insert the new corresponding value
      *(pldx + (it - pldi)) = l12_elem / d;
      // update the non-zero count
      // NOTE(review): no increment of pldnz[j] is visible here; presumably it
      // is on a missing line -- confirm.
      ld._set_nnz(ld.nnz() + 1);

    // insert the k-th column of L
    // (the updated diagonal first, then the sub-diagonal workspace values
    // divided by it)
      usize col_start = ld.col_start(permuted_pos);
      usize col_end = ld.col_end(permuted_pos);
      pldx[col_start] = diag_element;
      for (usize p = col_start + 1; p < col_end; ++p) {
        pldx[p] = plx2_storage[zx(pldi[p])] / diag_element;

  // set the parent of the k-th column of L
  // (its first sub-diagonal row, if the column has one)
  if (pldnz[permuted_pos] > 1) {
    etree[permuted_pos] = pldi[ld.col_start(permuted_pos) + 1];

  // len = number of sub-diagonal entries stored in column permuted_pos
  isize len = isize(util::zero_extend(ld.nnz_per_col()[permuted_pos])) - 1;
  // perform the rank update with the newly added column
  // NOTE(review): the trailing arguments of this call are not visible in
  // this chunk.
  ld = sparse::rank1_update<T, I>(ld,
                                  static_cast<I const*>(nullptr),
                                  VecRef<T, I>{
                                    pldi + ld.col_start(permuted_pos) + 1,
                                    pldx + ld.col_start(permuted_pos) + 1,

  return ld;
} // namespace sparse
} // namespace linalg
} // namespace proxsuite

#endif /* end of include guard PROXSUITE_LINALG_SPARSE_LDLT_ROWMOD_HPP */