Program Listing for File update.hpp
↰ Return to documentation for file (/tmp/ws/src/proxsuite/include/proxsuite/linalg/sparse/update.hpp)
//
// Copyright (c) 2022 INRIA
//
#ifndef PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP
#define PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP
#include "proxsuite/linalg/sparse/core.hpp"
#include <proxsuite/linalg/veg/tuple.hpp>
#include <algorithm>
#include <cstring>
namespace proxsuite {
namespace linalg {
namespace sparse {
/*
 * Scratch-memory requirement of merge_second_col_into_first when the column
 * being merged in (`second`) holds `second_size` indices: the merge allocates
 * a single buffer of at most `second_size` indices of type I from the stack.
 */
template<typename I>
auto
merge_second_col_into_first_req(proxsuite::linalg::veg::Tag<I> /*tag*/,
                                isize second_size) noexcept
  -> proxsuite::linalg::veg::dynstack::StackReq
{
  isize const index_bytes = isize{ sizeof(I) };
  isize const index_align = isize{ alignof(I) };
  return { second_size * index_bytes, index_align };
}
/*
 * Merges the index list `second` into the first column, whose indices are
 * stored in `first_ptr[0, first_initial_len)`, growing that column in place.
 *
 * Leading entries of `second` that are <= `ignore_threshold_inclusive` are
 * skipped. Of the remaining entries, those already present in the first
 * column are left alone; every other entry is inserted at its sorted
 * position in the first column and also written, in order, to `difference`.
 *
 * NOTE(review): the merge logic assumes both the first column and `second`
 * are sorted in increasing order without duplicates; this is not checked.
 *
 * Parameters:
 *   difference                 output buffer receiving the inserted indices;
 *                              at most `second.len()` entries are written.
 *   first_values               values parallel to `first_ptr`; only touched
 *                              when `move_values` is true.
 *   first_ptr                  indices of the first column, grown in place.
 *   first_full_len             capacity of `first_ptr` / `first_values`; only
 *                              used by the VEG_ASSERT bounds check.
 *   first_initial_len          number of valid entries of the first column on
 *                              entry.
 *   second                     indices to merge into the first column.
 *   ignore_threshold_inclusive leading entries of `second` <= this value are
 *                              ignored.
 *   move_values                when true, `first_values` is shifted together
 *                              with `first_ptr` and each newly inserted slot
 *                              is set to zero.
 *   stack                      scratch memory; at most `second.len()` indices
 *                              are allocated from it (see
 *                              merge_second_col_into_first_req).
 *
 * Returns (values, indices, difference): the first column's values and
 * indices with their new length, and the indices that were inserted.
 */
template<typename T, typename I>
auto
merge_second_col_into_first( //
  I* difference,
  T* first_values,
  I* first_ptr,
  PROXSUITE_MAYBE_UNUSED isize first_full_len,
  isize first_initial_len,
  Slice<I> second,
  proxsuite::linalg::veg::DoNotDeduce<I> ignore_threshold_inclusive,
  bool move_values,
  DynStackMut stack) noexcept(false)
  -> proxsuite::linalg::veg::Tuple<SliceMut<T>, SliceMut<I>, SliceMut<I>>
{
  VEG_CHECK_CONCEPT(trivially_copyable<I>);
  VEG_CHECK_CONCEPT(trivially_copyable<T>);
  // Nothing to merge: the first column is returned unchanged and the
  // difference is empty.
  if (second.len() == 0) {
    return {
      proxsuite::linalg::veg::tuplify,
      { unsafe, from_raw_parts, first_values, first_initial_len },
      { unsafe, from_raw_parts, first_ptr, first_initial_len },
      { unsafe, from_raw_parts, difference, 0 },
    };
  }
  I const* second_ptr = second.ptr();
  usize second_len = usize(second.len());
  usize index_second = 0;
  // Skip the leading entries of `second` that are <= the threshold.
  for (; index_second < second_len; ++index_second) {
    if (second_ptr[index_second] > ignore_threshold_inclusive) {
      break;
    }
  }
  auto ufirst_initial_len = usize(first_initial_len);
  // From here on `second_ptr[0, second_len)` only holds the kept entries.
  second_ptr += index_second;
  second_len -= index_second;
  index_second = 0;
  proxsuite::linalg::veg::Tag<I> tag{};
  // insert_pos_ptr[k] is the index, in the ORIGINAL first column, before
  // which the k-th new entry has to be inserted.
  auto _ins_pos = stack.make_new_for_overwrite(tag, isize(second_len));
  I* insert_pos_ptr = _ins_pos.ptr_mut();
  usize insert_count = 0;
  // Pass 1: walk both lists together. Every entry of `second` that is smaller
  // than the current first-column entry is new: record it in `difference`
  // along with its insertion position.
  for (usize index_first = 0; index_first < ufirst_initial_len; ++index_first) {
    I current_first = first_ptr[index_first];
    while (true) {
      if (!(index_second < second_len)) {
        break;
      }
      I current_second = second_ptr[index_second];
      if (!(current_second < current_first)) {
        break;
      }
      insert_pos_ptr[insert_count] = I(index_first);
      difference[insert_count] = current_second;
      ++insert_count;
      ++index_second;
    }
    // All of `second` consumed: the rest of the first column is unaffected.
    if (index_second == second_len) {
      break;
    }
    // Entry already present in the first column: skip it, it is not new.
    if (second_ptr[index_second] == current_first) {
      ++index_second;
    }
  }
  usize remaining_insert_count = insert_count;
  // Entries of `second` past the last first-column entry are appended.
  usize first_new_len =
    ufirst_initial_len + insert_count + (second_len - index_second);
  VEG_ASSERT(usize(first_full_len) >= first_new_len);
  usize append_count = second_len - index_second;
  // The appended tail goes both after the recorded insertions in
  // `difference` and after the (grown) end of the first column.
  std::memmove( //
    difference + insert_count,
    second_ptr + index_second,
    append_count * sizeof(I));
  std::memmove( //
    first_ptr + (ufirst_initial_len + insert_count),
    second_ptr + index_second,
    append_count * sizeof(I));
  if (move_values) {
    // Appended entries get a zero value.
    for (usize i = 0; i < append_count; ++i) {
      first_values[i + ufirst_initial_len + insert_count] = 0;
    }
  }
  // Pass 2: process insertions back to front. The run of original entries
  // starting at the k-th insertion position is shifted right by the number of
  // insertions up to and including k (so nothing not yet moved is
  // overwritten), and the new index (with a zero value) is written into the
  // slot freed just before it.
  while (remaining_insert_count != 0) {
    usize old_insert_pos = usize(insert_pos_ptr[remaining_insert_count - 1]);
    // Length of the run: up to the next insertion position, or to the end of
    // the original column for the last insertion.
    usize range_size =
      (remaining_insert_count == insert_count)
        ? ufirst_initial_len - old_insert_pos
        : usize(insert_pos_ptr[remaining_insert_count]) - old_insert_pos;
    usize old_pos = old_insert_pos;
    usize new_pos = old_pos + remaining_insert_count;
    std::memmove( //
      first_ptr + new_pos,
      first_ptr + old_pos,
      range_size * sizeof(I));
    if (move_values) {
      std::memmove( //
        first_values + new_pos,
        first_values + old_pos,
        range_size * sizeof(T));
      first_values[new_pos - 1] = 0;
    }
    first_ptr[new_pos - 1] = difference[remaining_insert_count - 1];
    --remaining_insert_count;
  }
  return {
    proxsuite::linalg::veg::tuplify,
    { unsafe, from_raw_parts, first_values, isize(first_new_len) },
    { unsafe, from_raw_parts, first_ptr, isize(first_new_len) },
    { unsafe, from_raw_parts, difference, isize(insert_count + append_count) },
  };
}
/*
 * Scratch-memory requirement of rank1_update for an n x n factor updated by a
 * vector with `col_nnz` nonzeros; `id_perm` is true when no permutation is
 * applied (perm_inv == nullptr).
 */
template<typename T, typename I>
auto
rank1_update_req( //
  proxsuite::linalg::veg::Tag<T> /*tag*/,
  proxsuite::linalg::veg::Tag<I> /*tag*/,
  isize n,
  bool id_perm,
  isize col_nnz) noexcept -> proxsuite::linalg::veg::dynstack::StackReq
{
  using proxsuite::linalg::veg::dynstack::StackReq;
  isize const index_bytes = isize{ sizeof(I) };
  isize const index_align = isize{ alignof(I) };

  // Permuted copy of the update vector's row indices, held for the whole
  // update; not needed with the identity permutation.
  isize const permuted_bytes = id_perm ? isize(0) : col_nnz * index_bytes;
  StackReq permuted_indices = { permuted_bytes, index_align };

  // Symbolic phase: two index buffers of length n (the difference buffer and
  // its backup), alive together with the merge helper's own scratch.
  StackReq one_difference = { n * index_bytes, index_align };
  StackReq symbolic_phase = (one_difference & one_difference) &
                            sparse::merge_second_col_into_first_req(
                              proxsuite::linalg::veg::Tag<I>{}, n);

  // Numerical phase: a dense work vector of n scalars.
  StackReq numerical_phase = { n * isize{ sizeof(T) }, isize{ alignof(T) } };

  // The two phases run one after the other (their buffers live in separate
  // scopes of rank1_update), so they are combined with `|`.
  return permuted_indices & (symbolic_phase | numerical_phase);
}
/*
 * In-place rank-one update of a sparse LDL^T factorization.
 *
 * `ld` must be in non-compressed form (asserted). For each column j, the slot
 * at col_ptrs[j] holds the diagonal entry D_jj. The following slots hold the
 * off-diagonal row indices and values of L.
 *   - `nnz_per_col[j]` counts the used slots, including the diagonal.
 *   - `col_ptrs[j + 1] - col_ptrs[j]` is the storage available to column j.
 *
 * The update proceeds in two phases.
 *   1. Symbolic: the sparsity pattern of w is propagated up the elimination
 *      tree. Affected columns are grown in place (new entries get value 0)
 *      and `etree` is re-parented where a column's first off-diagonal row
 *      changes.
 *   2. Numerical: for every column j on the etree path starting at the
 *      smallest permuted index of w, apply
 *        d' = d + alpha * w_j^2,  beta = alpha * w_j / d',
 *        alpha <- alpha - d' * beta^2,
 *      then update the off-diagonal entries of column j of L and the
 *      remaining entries of w.
 *      This is the classical rank-one LDL^T update recurrence, i.e. the
 *      result presumably factors (old matrix) + alpha * w w^T in the permuted
 *      ordering. NOTE(review): confirm the convention with callers.
 *
 * Parameters:
 *   ld        non-compressed factor, updated in place (pattern and values).
 *   etree     elimination tree: etree[j] is the parent of column j, or -1 for
 *             a root; updated in place.
 *   perm_inv  maps the row indices of w to the factor's ordering; nullptr
 *             means the identity.
 *   w         sparse update vector.
 *   alpha     scaling of the rank-one term.
 *   stack     scratch memory; see rank1_update_req.
 *
 * Returns `ld`.
 */
template<typename T, typename I>
auto
rank1_update(MatMut<T, I> ld,
             I* etree,
             I const* perm_inv,
             VecRef<T, I> w,
             proxsuite::linalg::veg::DoNotDeduce<T> alpha,
             DynStackMut stack) noexcept(false) -> MatMut<T, I>
{
  VEG_ASSERT(!ld.is_compressed());
  if (w.nnz() == 0) {
    return ld;
  }
  proxsuite::linalg::veg::Tag<I> tag;
  usize n = usize(ld.ncols());
  bool id_perm = perm_inv == nullptr;

  // Row indices of w in the factor's ordering, sorted increasingly. With the
  // identity permutation the indices of w are used directly.
  auto _w_permuted_indices =
    stack.make_new_for_overwrite(tag, id_perm ? isize(0) : w.nnz());
  auto w_permuted_indices =
    id_perm ? w.row_indices() : _w_permuted_indices.ptr();
  if (!id_perm) {
    I* pw_permuted_indices = _w_permuted_indices.ptr_mut();
    for (usize k = 0; k < usize(w.nnz()); ++k) {
      usize i = util::zero_extend(w.row_indices()[k]);
      pw_permuted_indices[k] = perm_inv[i];
    }
    std::sort(pw_permuted_indices, pw_permuted_indices + w.nnz());
  }
  auto sx = util::sign_extend;
  auto zx = util::zero_extend;

  // symbolic update: walk up the elimination tree from the smallest index of
  // w, merging the incoming pattern into each visited column.
  {
    usize current_col = zx(w_permuted_indices[0]);
    auto _difference =
      stack.make_new_for_overwrite(tag, isize(n - current_col));
    auto _difference_backup =
      stack.make_new_for_overwrite(tag, isize(n - current_col));
    auto merge_col = w_permuted_indices;
    isize merge_col_len = w.nnz();
    I* difference = _difference.ptr_mut();
    while (true) {
      usize old_parent = sx(etree[isize(current_col)]);
      usize current_ptr_idx = zx(ld.col_ptrs()[isize(current_col)]);
      usize next_ptr_idx = zx(ld.col_ptrs()[isize(current_col) + 1]);
      // The off-diagonal part of the column starts one slot after the
      // diagonal, so both its current length and its capacity exclude that
      // slot; otherwise the bounds check would let the column overflow into
      // the next column's diagonal.
      VEG_BIND(auto,
               (_, new_current_col, computed_difference),
               sparse::merge_second_col_into_first(
                 difference,
                 ld.values_mut() + (current_ptr_idx + 1),
                 ld.row_indices_mut() + (current_ptr_idx + 1),
                 isize(next_ptr_idx - current_ptr_idx) - 1,
                 isize(zx(ld.nnz_per_col()[isize(current_col)])) - 1,
                 proxsuite::linalg::veg::Slice<I>{
                   unsafe, from_raw_parts, merge_col, merge_col_len },
                 I(current_col),
                 true,
                 stack));
      (void)_;
      // Keep the global and per-column counts in sync (+1 for the diagonal).
      ld._set_nnz(ld.nnz() + new_current_col.len() + 1 -
                  isize(ld.nnz_per_col()[isize(current_col)]));
      ld.nnz_per_col_mut()[isize(current_col)] = I(new_current_col.len() + 1);
      // The parent of a column is its first off-diagonal row index.
      usize new_parent =
        (new_current_col.len() == 0) ? usize(-1) : sx(new_current_col[0]);
      if (new_parent == usize(-1)) {
        break;
      }
      if (new_parent == old_parent) {
        // Same parent: only the indices newly inserted into this column are
        // merged into the parent next.
        merge_col = computed_difference.ptr();
        merge_col_len = computed_difference.len();
        difference = _difference_backup.ptr_mut();
      } else {
        // New parent: the column's whole new pattern is merged into it, and
        // the elimination tree is re-parented.
        merge_col = new_current_col.ptr();
        merge_col_len = new_current_col.len();
        difference = _difference.ptr_mut();
        etree[isize(current_col)] = I(new_parent);
      }
      current_col = new_parent;
    }
  }

  // numerical update along the (updated) etree path starting at first_col.
  {
    usize first_col = zx(w_permuted_indices[0]);
    auto _work =
      stack.make_new_for_overwrite(proxsuite::linalg::veg::Tag<T>{}, isize(n));
    T* pwork = _work.ptr_mut();
    // Only entries on the path are read below, so only those are cleared.
    for (usize col = first_col; col != usize(-1); col = sx(etree[isize(col)])) {
      pwork[col] = 0;
    }
    // Scatter w into the dense work vector, in the factor's ordering.
    for (usize p = 0; p < usize(w.nnz()); ++p) {
      pwork[id_perm ? zx(w.row_indices()[isize(p)])
                    : zx(perm_inv[w.row_indices()[isize(p)]])] =
        w.values()[isize(p)];
    }
    I const* pldi = ld.row_indices();
    T* pldx = ld.values_mut();
    for (usize col = first_col; col != usize(-1); col = sx(etree[isize(col)])) {
      auto col_start = ld.col_start(col);
      auto col_end = ld.col_end(col);
      T w0 = pwork[col];
      T old_d = pldx[col_start];
      T new_d = old_d + alpha * w0 * w0;
      T beta = alpha * w0 / new_d;
      alpha = alpha - new_d * beta * beta;
      pldx[col_start] = new_d;
      pwork[col] -= w0;
      for (usize p = col_start + 1; p < col_end; ++p) {
        usize i = util::zero_extend(pldi[p]);
        T tmp = pldx[p];
        pwork[i] = pwork[i] - w0 * tmp;
        pldx[p] = tmp + beta * pwork[i];
      }
    }
  }
  return ld;
}
} // namespace sparse
} // namespace linalg
} // namespace proxsuite
#endif /* end of include guard PROXSUITE_LINALG_SPARSE_LDLT_UPDATE_HPP */