Program Listing for File srdf.hxx

Return to documentation for file (include/pinocchio/parsers/srdf.hxx)

//
// Copyright (c) 2017-2022 CNRS INRIA
//

#ifndef __pinocchio_parser_srdf_hxx__
#define __pinocchio_parser_srdf_hxx__

#include "pinocchio/parsers/srdf.hpp"

#include "pinocchio/multibody/model.hpp"
#include "pinocchio/multibody/geometry.hpp"
#include "pinocchio/algorithm/joint-configuration.hpp"

#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Read XML file with boost
#include <boost/property_tree/xml_parser.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/foreach.hpp>

namespace pinocchio
{
  namespace srdf
  {
    namespace details
    {
      ///
      /// \brief Disable the collision pairs listed in an SRDF stream.
      ///
      /// \details The stream is parsed as XML. Each <disable_collisions link1="..." link2="..."/>
      ///          tag found directly under <robot> is processed in turn. Every entry of
      ///          geom_model.collisionPairs whose two geometry objects have as parentFrame the
      ///          frames returned by model.getBodyId(link1) and model.getBodyId(link2), in either
      ///          order, is removed.
      ///
      /// \param[in] model           Kinematic model used to resolve link names to frame indexes.
      /// \param[in,out] geom_model  Geometry model whose collision pairs are pruned.
      /// \param[in] stream          Stream holding the SRDF (XML) content.
      /// \param[in] verbose         If true, report skipped tags and removed pairs on std::cout.
      ///
      /// \note Errors raised by boost::property_tree (XML parse error, missing <robot> child,
      ///       missing link1/link2 attribute) are not caught here and reach the caller.
      ///
      template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
      void removeCollisionPairs(
        const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
        GeometryModel & geom_model,
        std::istream & stream,
        const bool verbose = false)
      {
        typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;

        // Read xml stream
        using boost::property_tree::ptree;
        ptree pt;
        read_xml(stream, pt);

        // Iterate over the direct children of <robot>
        BOOST_FOREACH (const ptree::value_type & v, pt.get_child("robot"))
        {
          if (v.first != "disable_collisions")
            continue;

          const std::string link1 = v.second.get<std::string>("<xmlattr>.link1");
          const std::string link2 = v.second.get<std::string>("<xmlattr>.link2");

          // Check first if the two bodies exist in model
          if (!model.existBodyName(link1) || !model.existBodyName(link2))
          {
            if (verbose)
              std::cout << "It seems that " << link1 << " or " << link2
                        << " do not exist in model. Skip." << std::endl;
            continue;
          }

          const typename Model::FrameIndex frame_id1 = model.getBodyId(link1);
          const typename Model::FrameIndex frame_id2 = model.getBodyId(link2);

          // Malformed SRDF
          if (frame_id1 == frame_id2)
          {
            if (verbose)
              std::cout << "Cannot disable collision between " << link1 << " and " << link2
                        << std::endl;
            continue;
          }

          bool didRemove = false;
          PairIndex k = 0;
          while (k < geom_model.collisionPairs.size())
          {
            // Work on a copy: removeCollisionPair() erases from the very vector this element
            // lives in, so passing a reference into that vector would alias its argument.
            const CollisionPair cp = geom_model.collisionPairs[k];
            const bool remove =
              ((geom_model.geometryObjects[cp.first].parentFrame == frame_id1)
               && (geom_model.geometryObjects[cp.second].parentFrame == frame_id2))
              || ((geom_model.geometryObjects[cp.second].parentFrame == frame_id1)
                  && (geom_model.geometryObjects[cp.first].parentFrame == frame_id2));

            if (remove)
            {
              const std::size_t size_before = geom_model.collisionPairs.size();
              geom_model.removeCollisionPair(cp);
              if (geom_model.collisionPairs.size() < size_before)
              {
                // The next pair has shifted into slot k: examine it without advancing.
                didRemove = true;
                continue;
              }
            }
            // Either kept, or removal made no progress: move on, so the loop always terminates.
            ++k;
          }

          if (didRemove && verbose)
            std::cout << "Remove collision pair (" << link1 << "," << link2 << ")" << std::endl;
        } // BOOST_FOREACH
      }
    } // namespace details

    ///
    /// \brief Disable, in geom_model, the collision pairs listed in an SRDF file.
    ///
    /// \param[in] model           Kinematic model used to resolve SRDF link names.
    /// \param[in,out] geom_model  Geometry model whose collision pairs are pruned.
    /// \param[in] filename        Path to the SRDF file; the text after its last '.' must be "srdf".
    /// \param[in] verbose         If true, report skipped tags and removed pairs on std::cout.
    ///
    /// \throws std::invalid_argument if the extension check fails or the file cannot be opened.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    void removeCollisionPairs(
      const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      GeometryModel & geom_model,
      const std::string & filename,
      const bool verbose)
    {
      // Everything after the last '.' (the whole name when there is no '.') must read "srdf".
      const std::string::size_type dot_pos = filename.find_last_of('.');
      if (filename.substr(dot_pos + 1) != "srdf")
        throw std::invalid_argument(filename + " does not have the right extension.");

      std::ifstream file_stream(filename.c_str());
      if (!file_stream.is_open())
        throw std::invalid_argument(filename + " does not seem to be a valid file.");

      details::removeCollisionPairs(model, geom_model, file_stream, verbose);
    }

    ///
    /// \brief Disable, in geom_model, the collision pairs listed in an in-memory SRDF string.
    ///
    /// \param[in] model           Kinematic model used to resolve SRDF link names.
    /// \param[in,out] geom_model  Geometry model whose collision pairs are pruned.
    /// \param[in] xmlString       SRDF (XML) content.
    /// \param[in] verbose         If true, report skipped tags and removed pairs on std::cout.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    void removeCollisionPairsFromXML(
      const ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      GeometryModel & geom_model,
      const std::string & xmlString,
      const bool verbose)
    {
      // Wrap the string in a stream so the shared stream-based parser can consume it.
      std::istringstream xml_stream(xmlString);
      details::removeCollisionPairs(model, geom_model, xml_stream, verbose);
    }

    ///
    /// \brief Load rotor parameters from the first <rotor_params> tag of an SRDF file.
    ///
    /// \details Each <joint name="..." mass="..." gear_ratio="..."/> child of that tag is
    ///          processed as follows:
    ///          - the matching model joint must have nv == 1;
    ///          - mass * gear_ratio^2 is added to model.armature at the joint's velocity index;
    ///          - model.rotorInertia and model.rotorGearRatio at that index are overwritten.
    ///          Joints unknown to the model are skipped.
    ///
    /// \param[in,out] model  Model whose armature / rotor parameters are updated.
    /// \param[in] filename   Path to the SRDF file; the text after its last '.' must be "srdf".
    /// \param[in] verbose    If true, print each parsed entry and unknown joints on std::cout.
    ///
    /// \returns true once a <rotor_params> tag has been processed.
    /// \throws std::invalid_argument on a bad extension or a file that cannot be opened; the final
    ///         input check is reached (and presumably throws) when no <rotor_params> tag exists.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    bool loadRotorParameters(
      ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      const std::string & filename,
      const bool verbose)
    {
      typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
      typedef typename Model::JointModel JointModel;

      // Check extension
      const std::string extension = filename.substr(filename.find_last_of('.') + 1);
      if (extension != "srdf")
      {
        const std::string exception_message(filename + " does not have the right extension.");
        throw std::invalid_argument(exception_message);
      }

      // Open file
      std::ifstream srdf_stream(filename.c_str());
      if (!srdf_stream.is_open())
      {
        const std::string exception_message(filename + " does not seem to be a valid file.");
        throw std::invalid_argument(exception_message);
      }

      // Read xml stream
      using boost::property_tree::ptree;
      ptree pt;
      read_xml(srdf_stream, pt);

      // Iterate over all tags directly children of robot
      BOOST_FOREACH (const ptree::value_type & v, pt.get_child("robot"))
      {
        // Only the first <rotor_params> tag is handled (we return right after it).
        if (v.first == "rotor_params")
        {
          // Iterate over all the joint tags. Named joint_tag so it is not shadowed by the
          // model joint declared below.
          BOOST_FOREACH (const ptree::value_type & joint_tag, v.second)
          {
            if (joint_tag.first == "joint")
            {
              const std::string joint_name = joint_tag.second.get<std::string>("<xmlattr>.name");
              const Scalar rotor_inertia =
                static_cast<Scalar>(joint_tag.second.get<double>("<xmlattr>.mass"));
              const Scalar rotor_gear_ratio =
                static_cast<Scalar>(joint_tag.second.get<double>("<xmlattr>.gear_ratio"));
              if (verbose)
              {
                std::cout << "(" << joint_name << " , " << rotor_inertia << " , "
                          << rotor_gear_ratio << ")" << std::endl;
              }
              // Search in model the joint and its config id
              const typename Model::JointIndex joint_id = model.getJointId(joint_name);

              if (joint_id != model.joints.size()) // != model.njoints
              {
                const JointModel & joint = model.joints[joint_id];
                // Rotor parameters are stored per velocity index: only 1-dof joints are supported.
                PINOCCHIO_CHECK_INPUT_ARGUMENT(joint.nv() == 1);

                // Reflected rotor inertia through the gearbox: I_rotor * N^2.
                model.armature[joint.idx_v()] +=
                  rotor_inertia * rotor_gear_ratio * rotor_gear_ratio;
                model.rotorInertia(joint.idx_v()) = rotor_inertia;
                model.rotorGearRatio(joint.idx_v()) = rotor_gear_ratio;
              }
              else
              {
                if (verbose)
                  std::cout << "The Joint " << joint_name << " was not found in model" << std::endl;
              }
            }
          }
          return true;
        }
      }
      // No <rotor_params> tag was found. The input-check macro is expected to throw here.
      PINOCCHIO_CHECK_INPUT_ARGUMENT(false, "no rotor params found in the SRDF file");
      // Only reached if the check above does not throw; also keeps compilers from warning
      // about a missing return value.
      return false;
    }

    ///
    /// \brief Joint visitor that writes the values parsed from an SRDF <joint value="..."/> tag
    ///        into the joint's segment of a full configuration vector.
    ///
    /// \details Revolute-unbounded joints accept a single angle, stored as (cos, sin). Every
    ///          other joint expects exactly nq values, copied as-is. On a size mismatch, a message
    ///          is printed on std::cerr and config is left unchanged.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    struct LoadReferenceConfigurationStep
    : fusion::JointUnaryVisitorBase<
        LoadReferenceConfigurationStep<Scalar, Options, JointCollectionTpl>>
    {
      typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
      typedef typename Model::ConfigVectorType ConfigVectorType;

      typedef boost::fusion::
        vector<const std::string &, const ConfigVectorType &, ConfigVectorType &>
          ArgsType;

      /// Visitor entry point: forwards to the overload matching the concrete joint type.
      template<typename JointModel>
      static void algo(
        const JointModelBase<JointModel> & joint,
        const std::string & joint_name,
        const ConfigVectorType & fromXML,
        ConfigVectorType & config)
      {
        const JointModel & concrete_joint = joint.derived();
        algo_impl(concrete_joint, joint_name, fromXML, config);
      }

    private:
      /// Revolute unbounded: one angle in the SRDF, two (cos, sin) entries in the config.
      template<int axis>
      static void algo_impl(
        const JointModelRevoluteUnboundedTpl<Scalar, Options, axis> & joint,
        const std::string & joint_name,
        const ConfigVectorType & fromXML,
        ConfigVectorType & config)
      {
        typedef JointModelRevoluteUnboundedTpl<Scalar, Options, axis> JointModelRUB;
        PINOCCHIO_STATIC_ASSERT(
          JointModelRUB::NQ == 2, JOINT_MODEL_REVOLUTE_SHOULD_HAVE_2_PARAMETERS);

        if (fromXML.size() == 1)
        {
          const int q_start = joint.idx_q();
          // sin goes to q_start + 1, cos to q_start.
          SINCOS(fromXML[0], &config[q_start + 1], &config[q_start + 0]);
          return;
        }

        std::cerr << "Could not read joint config (" << joint_name << " , " << fromXML.transpose()
                  << ")" << std::endl;
      }

      /// Any other joint: the SRDF must provide exactly nq values.
      template<typename JointModel>
      static void algo_impl(
        const JointModel & joint,
        const std::string & joint_name,
        const ConfigVectorType & fromXML,
        ConfigVectorType & config)
      {
        if (fromXML.size() == joint.nq())
        {
          config.segment(joint.idx_q(), joint.nq()) = fromXML;
          return;
        }

        std::cerr << "Could not read joint config (" << joint_name << " , " << fromXML.transpose()
                  << ")" << std::endl;
      }
    };

    ///
    /// \brief Load the <group_state> reference configurations of an SRDF file into
    ///        model.referenceConfigurations.
    ///
    /// \param[in,out] model  Model receiving the reference configurations.
    /// \param[in] filename   Path to the SRDF file; the text after its last '.' must be "srdf".
    /// \param[in] verbose    If true, print parsed joint values and warnings on std::cout.
    ///
    /// \throws std::invalid_argument if the extension check fails or the file cannot be opened.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    void loadReferenceConfigurations(
      ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      const std::string & filename,
      const bool verbose)
    {
      // Everything after the last '.' (the whole name when there is no '.') must read "srdf".
      const std::string::size_type dot_pos = filename.find_last_of('.');
      if (filename.substr(dot_pos + 1) != "srdf")
        throw std::invalid_argument(filename + " does not have the right extension.");

      std::ifstream file_stream(filename.c_str());
      if (!file_stream.is_open())
        throw std::invalid_argument(filename + " does not seem to be a valid file.");

      loadReferenceConfigurationsFromXML(model, file_stream, verbose);
    }

    ///
    /// \brief Parse the <group_state> tags of an SRDF stream into model.referenceConfigurations.
    ///
    /// \details Each <group_state name="..."> starts from the neutral configuration of the
    ///          model. For each <joint name="..." value="v0 v1 ..."/> child whose joint exists in
    ///          the model, the whitespace-separated values are written into that joint's segment
    ///          by LoadReferenceConfigurationStep. The result is inserted under the group_state
    ///          name. If the name is already present, insert() leaves the existing entry in place.
    ///
    /// \param[in,out] model     Model receiving the reference configurations.
    /// \param[in] xmlStream     Stream holding the SRDF (XML) content.
    /// \param[in] verbose       If true, print parsed joint values and warnings on std::cout.
    ///
    template<typename Scalar, int Options, template<typename, int> class JointCollectionTpl>
    void loadReferenceConfigurationsFromXML(
      ModelTpl<Scalar, Options, JointCollectionTpl> & model,
      std::istream & xmlStream,
      const bool verbose)
    {
      typedef ModelTpl<Scalar, Options, JointCollectionTpl> Model;
      typedef typename Model::JointModel JointModel;
      typedef typename Model::ConfigVectorType ConfigVectorType;
      typedef LoadReferenceConfigurationStep<Scalar, Options, JointCollectionTpl>
        LoadReferenceConfigurationStep_t;

      // Read xml stream
      using boost::property_tree::ptree;
      ptree pt;
      read_xml(xmlStream, pt);

      // Iterate over all tags directly children of robot
      BOOST_FOREACH (const ptree::value_type & v, pt.get_child("robot"))
      {
        if (v.first != "group_state")
          continue;

        const std::string name = v.second.get<std::string>("<xmlattr>.name");
        ConfigVectorType ref_config(model.nq);
        neutral(model, ref_config);

        // Iterate over all the joint tags
        BOOST_FOREACH (const ptree::value_type & joint_tag, v.second)
        {
          if (joint_tag.first != "joint")
            continue;

          const std::string joint_name = joint_tag.second.get<std::string>("<xmlattr>.name");
          const typename Model::JointIndex joint_id = model.getJointId(joint_name);

          if (joint_id == model.joints.size()) // == model.njoints, i.e. not found
          {
            if (verbose)
              std::cout << "The Joint " << joint_name << " was not found in model" << std::endl;
            continue;
          }

          const JointModel & joint = model.joints[joint_id];

          // Parse the whitespace-separated numbers of the "value" attribute.
          const std::string joint_val = joint_tag.second.get<std::string>("<xmlattr>.value");
          std::istringstream config_string(joint_val);
          const std::vector<double> config_vec(
            (std::istream_iterator<double>(config_string)), std::istream_iterator<double>());

          // Explicit cast: assigning a VectorXd map straight to ConfigVectorType only compiles
          // when Scalar is double.
          const ConfigVectorType joint_config =
            Eigen::Map<const Eigen::VectorXd>(
              config_vec.data(), (Eigen::DenseIndex)config_vec.size())
              .template cast<Scalar>();

          LoadReferenceConfigurationStep_t::run(
            joint, typename LoadReferenceConfigurationStep_t::ArgsType(
                     joint_name, joint_config, ref_config));
          if (verbose)
          {
            std::cout << "(" << joint_name << " , " << joint_config.transpose() << ")"
                      << std::endl;
          }
        }

        // insert() does not overwrite an existing key (its .second is then false), so the
        // configuration stored first keeps being used.
        if (!model.referenceConfigurations.insert(std::make_pair(name, ref_config)).second)
        {
          if (verbose)
            std::cout << "The reference configuration " << name
                      << " has been defined multiple times. "
                      << "Only the first instance of " << name << " is being used." << std::endl;
        }
      } // BOOST_FOREACH
    }
  } // namespace srdf
} // namespace pinocchio

#endif // ifndef __pinocchio_parser_srdf_hxx__