Program Listing for File blackboard_pywrapper.hpp

Return to documentation for file (include/yasmin/blackboard_pywrapper.hpp)

// Copyright (C) 2025 Miguel Ángel González Santamarta
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

#ifndef YASMIN__BLACKBOARD_PYWRAPPER_HPP_
#define YASMIN__BLACKBOARD_PYWRAPPER_HPP_

#include <pybind11/cast.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <vector>

#include "yasmin/blackboard.hpp"
#include "yasmin/types.hpp"

namespace py = pybind11;

namespace yasmin {

class BlackboardPyWrapper {
private:
  Blackboard::SharedPtr blackboard;

  using StringVector = std::vector<std::string>;
  using IntVector = std::vector<std::int64_t>;
  using FloatVector = std::vector<double>;
  using BoolVector = std::vector<bool>;

  using StringDict = std::unordered_map<std::string, std::string>;
  using IntDict = std::unordered_map<std::string, std::int64_t>;
  using FloatDict = std::unordered_map<std::string, double>;
  using BoolDict = std::unordered_map<std::string, bool>;

  static std::vector<uint8_t> py_buffer_to_vector(const py::handle &value) {
    char *data = nullptr;
    py::ssize_t size = 0;

    if (PyBytes_Check(value.ptr()) != 0) {
      if (PyBytes_AsStringAndSize(value.ptr(), &data, &size) != 0) {
        throw py::error_already_set();
      }
    } else if (PyByteArray_Check(value.ptr()) != 0) {
      data = PyByteArray_AsString(value.ptr());
      size = PyByteArray_Size(value.ptr());
    } else {
      throw std::runtime_error("Expected bytes or bytearray object");
    }

    return std::vector<uint8_t>(reinterpret_cast<uint8_t *>(data),
                                reinterpret_cast<uint8_t *>(data) + size);
  }

  template <typename ByteT>
  static py::bytes byte_vector_to_py_bytes(const std::vector<ByteT> &value) {
    return py::bytes(reinterpret_cast<const char *>(value.data()),
                     static_cast<py::ssize_t>(value.size()));
  }

  static bool is_python_bytes_like(const py::handle &value) {
    return PyBytes_Check(value.ptr()) != 0 ||
           PyByteArray_Check(value.ptr()) != 0;
  }

  static bool is_python_int_like(const py::handle &value) {
    return py::isinstance<py::int_>(value) && !py::isinstance<py::bool_>(value);
  }

  static bool is_python_float_like(const py::handle &value) {
    return py::isinstance<py::float_>(value);
  }

  static bool is_python_number_like(const py::handle &value) {
    return is_python_int_like(value) || is_python_float_like(value);
  }

  static bool is_python_sequence_like(const py::handle &value) {
    return py::isinstance<py::list>(value) || py::isinstance<py::tuple>(value);
  }

  template <typename Predicate>
  static bool sequence_matches(const py::sequence &seq, Predicate pred) {
    for (auto item : seq) {
      if (!pred(item)) {
        return false;
      }
    }
    return true;
  }

  template <typename T>
  static std::vector<T> sequence_to_vector(const py::sequence &seq) {
    std::vector<T> result;
    result.reserve(static_cast<std::size_t>(py::len(seq)));

    for (auto item : seq) {
      result.push_back(py::cast<T>(item));
    }

    return result;
  }

  static bool dict_has_only_string_keys(const py::dict &dict) {
    for (auto item : dict) {
      if (!py::isinstance<py::str>(item.first)) {
        return false;
      }
    }
    return true;
  }

  template <typename Predicate>
  static bool dict_values_match(const py::dict &dict, Predicate pred) {
    for (auto item : dict) {
      if (!pred(item.second)) {
        return false;
      }
    }
    return true;
  }

  template <typename T>
  static std::unordered_map<std::string, T>
  dict_to_unordered_map(const py::dict &dict) {
    std::unordered_map<std::string, T> result;
    result.reserve(static_cast<std::size_t>(py::len(dict)));

    for (auto item : dict) {
      result.emplace(py::cast<std::string>(item.first),
                     py::cast<T>(item.second));
    }

    return result;
  }

  template <typename T> static bool is_exact_cpp_type(const std::string &type) {
    return type == demangle_type(typeid(T).name());
  }

public:
  BlackboardPyWrapper() : blackboard(Blackboard::make_shared()) {}

  BlackboardPyWrapper(Blackboard &&other)
      : blackboard(Blackboard::make_shared(std::move(other))) {}

  explicit BlackboardPyWrapper(Blackboard::SharedPtr bb_ptr)
      : blackboard(std::move(bb_ptr)) {}

  void set(const std::string &key, py::object value) {
    if (py::isinstance<py::bool_>(value)) {
      this->blackboard->set<bool>(key, value.cast<bool>());
      return;
    }

    if (is_python_int_like(value)) {
      this->blackboard->set<std::int64_t>(key, value.cast<std::int64_t>());
      return;
    }

    if (py::isinstance<py::float_>(value)) {
      this->blackboard->set<double>(key, value.cast<double>());
      return;
    }

    if (is_python_bytes_like(value)) {
      this->blackboard->set<std::vector<uint8_t>>(key,
                                                  py_buffer_to_vector(value));
      return;
    }

    if (py::isinstance<py::str>(value)) {
      this->blackboard->set<std::string>(key, value.cast<std::string>());
      return;
    }

    if (is_python_sequence_like(value)) {
      py::sequence seq = value.cast<py::sequence>();

      // Empty sequences do not carry enough type information to infer a native
      // element type safely.
      if (py::len(seq) == 0) {
        this->blackboard->set<py::object>(key, value);
        return;
      }

      // Bool must be checked before int because Python bool is also an int.
      if (sequence_matches(seq, [](const py::handle &item) {
            return py::isinstance<py::bool_>(item);
          })) {
        this->blackboard->set<BoolVector>(key, sequence_to_vector<bool>(seq));
        return;
      }

      if (sequence_matches(seq, [](const py::handle &item) {
            return is_python_int_like(item);
          })) {
        this->blackboard->set<IntVector>(key,
                                         sequence_to_vector<std::int64_t>(seq));
        return;
      }

      // Mixed int/float sequences are normalized to double.
      if (sequence_matches(seq, [](const py::handle &item) {
            return is_python_number_like(item);
          })) {
        this->blackboard->set<FloatVector>(key,
                                           sequence_to_vector<double>(seq));
        return;
      }

      if (sequence_matches(seq, [](const py::handle &item) {
            return py::isinstance<py::str>(item);
          })) {
        this->blackboard->set<StringVector>(
            key, sequence_to_vector<std::string>(seq));
        return;
      }

      this->blackboard->set<py::object>(key, value);
      return;
    }

    if (py::isinstance<py::dict>(value)) {
      py::dict dict = value.cast<py::dict>();

      // Only dict[str, T] with a homogeneous value type is converted into a
      // native unordered_map.
      if (py::len(dict) == 0 || !dict_has_only_string_keys(dict)) {
        this->blackboard->set<py::object>(key, value);
        return;
      }

      if (dict_values_match(dict, [](const py::handle &item) {
            return py::isinstance<py::bool_>(item);
          })) {
        this->blackboard->set<BoolDict>(key, dict_to_unordered_map<bool>(dict));
        return;
      }

      if (dict_values_match(dict, [](const py::handle &item) {
            return is_python_int_like(item);
          })) {
        this->blackboard->set<IntDict>(
            key, dict_to_unordered_map<std::int64_t>(dict));
        return;
      }

      // Mixed int/float dictionaries are normalized to double.
      if (dict_values_match(dict, [](const py::handle &item) {
            return is_python_number_like(item);
          })) {
        this->blackboard->set<FloatDict>(key,
                                         dict_to_unordered_map<double>(dict));
        return;
      }

      if (dict_values_match(dict, [](const py::handle &item) {
            return py::isinstance<py::str>(item);
          })) {
        this->blackboard->set<StringDict>(
            key, dict_to_unordered_map<std::string>(dict));
        return;
      }

      this->blackboard->set<py::object>(key, value);
      return;
    }

    this->blackboard->set<py::object>(key, value);
  }

  py::object get(const std::string &key) const {
    const std::string type = this->blackboard->get_type(key);

    if (is_exact_cpp_type<std::vector<uint8_t>>(type)) {
      return byte_vector_to_py_bytes(
          this->blackboard->get<std::vector<uint8_t>>(key));
    }

    if (is_exact_cpp_type<std::vector<unsigned char>>(type)) {
      return byte_vector_to_py_bytes(
          this->blackboard->get<std::vector<unsigned char>>(key));
    }

    if (is_exact_cpp_type<std::vector<char>>(type)) {
      return byte_vector_to_py_bytes(
          this->blackboard->get<std::vector<char>>(key));
    }

    if (is_exact_cpp_type<StringVector>(type)) {
      return py::cast(this->blackboard->get<StringVector>(key));
    }

    if (is_exact_cpp_type<IntVector>(type)) {
      return py::cast(this->blackboard->get<IntVector>(key));
    }

    if (is_exact_cpp_type<std::vector<int>>(type)) {
      return py::cast(this->blackboard->get<std::vector<int>>(key));
    }

    if (is_exact_cpp_type<std::vector<long>>(type)) {
      return py::cast(this->blackboard->get<std::vector<long>>(key));
    }

    if (is_exact_cpp_type<std::vector<long long>>(type)) {
      return py::cast(this->blackboard->get<std::vector<long long>>(key));
    }

    if (is_exact_cpp_type<FloatVector>(type)) {
      return py::cast(this->blackboard->get<FloatVector>(key));
    }

    if (is_exact_cpp_type<std::vector<float>>(type)) {
      return py::cast(this->blackboard->get<std::vector<float>>(key));
    }

    if (is_exact_cpp_type<BoolVector>(type)) {
      return py::cast(this->blackboard->get<BoolVector>(key));
    }

    if (is_exact_cpp_type<StringDict>(type)) {
      return py::cast(this->blackboard->get<StringDict>(key));
    }

    if (is_exact_cpp_type<IntDict>(type)) {
      return py::cast(this->blackboard->get<IntDict>(key));
    }

    if (is_exact_cpp_type<std::unordered_map<std::string, int>>(type)) {
      return py::cast(
          this->blackboard->get<std::unordered_map<std::string, int>>(key));
    }

    if (is_exact_cpp_type<std::unordered_map<std::string, long>>(type)) {
      return py::cast(
          this->blackboard->get<std::unordered_map<std::string, long>>(key));
    }

    if (is_exact_cpp_type<std::unordered_map<std::string, long long>>(type)) {
      return py::cast(
          this->blackboard->get<std::unordered_map<std::string, long long>>(
              key));
    }

    if (is_exact_cpp_type<FloatDict>(type)) {
      return py::cast(this->blackboard->get<FloatDict>(key));
    }

    if (is_exact_cpp_type<std::unordered_map<std::string, float>>(type)) {
      return py::cast(
          this->blackboard->get<std::unordered_map<std::string, float>>(key));
    }

    if (is_exact_cpp_type<BoolDict>(type)) {
      return py::cast(this->blackboard->get<BoolDict>(key));
    }

    if (is_exact_cpp_type<std::string>(type)) {
      return py::cast(this->blackboard->get<std::string>(key));
    }

    if (is_exact_cpp_type<std::int64_t>(type)) {
      return py::cast(this->blackboard->get<std::int64_t>(key));
    }

    if (is_exact_cpp_type<int>(type)) {
      return py::cast(this->blackboard->get<int>(key));
    }

    if (is_exact_cpp_type<long>(type)) {
      return py::cast(this->blackboard->get<long>(key));
    }

    if (is_exact_cpp_type<float>(type)) {
      return py::cast(this->blackboard->get<float>(key));
    }

    if (is_exact_cpp_type<double>(type)) {
      return py::cast(this->blackboard->get<double>(key));
    }

    if (is_exact_cpp_type<bool>(type)) {
      return py::cast(this->blackboard->get<bool>(key));
    }

    if (is_exact_cpp_type<py::object>(type)) {
      return this->blackboard->get<py::object>(key);
    }

    return py::none();
  }

  void remove(const std::string &key) { this->blackboard->remove(key); }

  bool contains(const std::string &key) const {
    return this->blackboard->contains(key);
  }

  int size() const { return this->blackboard->size(); }

  std::vector<std::string> keys() const { return this->blackboard->keys(); }

  py::list values() const {
    py::list result;

    for (const auto &key : this->blackboard->keys()) {
      result.append(this->get(key));
    }

    return result;
  }

  py::list items() const {
    py::list result;

    for (const auto &key : this->blackboard->keys()) {
      result.append(py::make_tuple(key, this->get(key)));
    }

    return result;
  }

  std::string to_string() const { return this->blackboard->to_string(); }

  void set_remappings(const Remappings &remappings) {
    this->blackboard->set_remappings(remappings);
  }

  const Remappings &get_remappings() const {
    return this->blackboard->get_remappings();
  }

  Blackboard::SharedPtr get_cpp_blackboard() const { return this->blackboard; }
};

} // namespace yasmin

#endif // YASMIN__BLACKBOARD_PYWRAPPER_HPP_