Program Listing for File jwt.hpp

Return to documentation for file (include/jwt/jwt.hpp)

/*
Copyright (c) 2017 Arun Muralidharan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
 */

#ifndef JWT_HPP
#define JWT_HPP

#include <set>
#include <array>
#include <string>
#include <chrono>
#include <ostream>
#include <cassert>
#include <cstring>

#include "jwt/assertions.hpp"
#include "jwt/base64.hpp"
#include "jwt/config.hpp"
#include "jwt/algorithm.hpp"
#include "jwt/string_view.hpp"
#include "jwt/parameters.hpp"
#include "jwt/exceptions.hpp"
#include "json.hpp"

// For convenience.
// NOTE: these aliases are declared at global namespace scope (outside
// namespace jwt), so every translation unit including this header sees them.
//  - json_t        : the nlohmann JSON type used for header/payload storage.
//  - system_time_t : a std::chrono::system_clock time point (used for
//                    time-based claims such as "exp"/"nbf"/"iat").
//  - json_ns       : short alias for the nlohmann namespace.
using json_t = nlohmann::json;
using system_time_t = std::chrono::time_point<std::chrono::system_clock>;
namespace json_ns = nlohmann;

namespace jwt {

/**
 * Token type carried in the "typ" header parameter.
 */
enum class type
{
  NONE = 0, ///< Parsed from the string "none"; also set when "typ" is removed.
  JWT  = 1  ///< Parsed from the string "jwt" (case-insensitive).
};

/**
 * Convert a "typ" string to the corresponding jwt::type.
 *
 * Matching is ASCII case-insensitive ("JWT", "jwt", "None", ...), and it
 * honours the view's length, so @p typ does not need to be NUL-terminated.
 *
 * @param typ The type string; asserted (debug builds) to be non-empty.
 * @return type::JWT for "jwt", or type::NONE for "none".
 *         Any other string asserts in debug builds and yields type::NONE.
 */
inline enum type str_to_type(const jwt::string_view typ) noexcept
{
  assert (typ.length() && "Empty type string");

  // Length-aware, ASCII case-insensitive equality with a NUL-terminated
  // literal. Calling strcasecmp(typ.data(), ...) would read past the end of a
  // view that is not NUL-terminated.
  const auto iequals = [typ](const char* lit) noexcept -> bool {
    const std::size_t len = std::strlen(lit);
    if (typ.length() != len) return false;
    for (std::size_t i = 0; i < len; ++i) {
      char a = typ.data()[i];
      char b = lit[i];
      if (a >= 'A' && a <= 'Z') a = static_cast<char>(a - 'A' + 'a');
      if (b >= 'A' && b <= 'Z') b = static_cast<char>(b - 'A' + 'a');
      if (a != b) return false;
    }
    return true;
  };

  if (iequals("jwt"))  return type::JWT;
  if (iequals("none")) return type::NONE;

  // An unrecognised string is a property of the input, not a programming
  // error. Keep the debug diagnostic, but return a defined value instead of
  // going through a "not reached" marker. That marker may act as an optimizer
  // hint, which would make this path undefined in release builds.
  assert (0 && "Unknown type string");
  return type::NONE;
}


/**
 * Convert a jwt::type to the string stored in the "typ" header.
 *
 * This is the inverse of str_to_type for both enumerators: type::JWT maps to
 * "JWT", and type::NONE maps to "none" (the string str_to_type parses as NONE).
 * Before this, NONE was unhandled, so jwt_header::typ("none") hit the
 * assertion / not-reached path.
 *
 * @param typ The token type.
 * @return A view of a string literal (static storage duration).
 */
inline jwt::string_view type_to_str(SCOPED_ENUM type typ)
{
  switch (typ) {
    case type::JWT:  return "JWT";
    case type::NONE: return "none";
    default:         assert (0 && "Unknown type");
  };

  JWT_NOT_REACHED("Code not reached");
  // Never flow off the end of a non-void function (mirrors reg_claims_to_str).
  return "";
}


/**
 * The registered claim names recognised by this library.
 * reg_claims_to_str() maps each enumerator to its short JSON key.
 */
enum class registered_claims
{
  expiration = 0, ///< Expiration Time claim ("exp")
  not_before = 1, ///< Not Before Time claim ("nbf")
  issuer     = 2, ///< Issuer name claim ("iss")
  audience   = 3, ///< Audience claim ("aud")
  issued_at  = 4, ///< Issued At Time claim ("iat")
  subject    = 5, ///< Subject claim ("sub")
  jti        = 6  ///< JWT ID claim ("jti")
};


/**
 * Map a registered claim to its JSON member name.
 *
 * @param claim One of the registered_claims enumerators.
 * @return A view of the claim's string literal ("exp", "nbf", ...), or an
 *         empty view (after a debug assertion) for an out-of-range value.
 */
inline jwt::string_view reg_claims_to_str(SCOPED_ENUM registered_claims claim) noexcept
{
  using rc = registered_claims;

  switch (claim)
  {
    case rc::expiration: return "exp";
    case rc::not_before: return "nbf";
    case rc::issuer:     return "iss";
    case rc::audience:   return "aud";
    case rc::issued_at:  return "iat";
    case rc::subject:    return "sub";
    case rc::jti:        return "jti";
  }

  // Only reachable for a value outside the enumerator list.
  assert (0 && "Not a registered claim");
  JWT_NOT_REACHED("Code not reached");
  return "";
}

/**
 * Set types used to track header and claim names.
 */
struct jwt_set
{
  /**
   * Transparent strict-weak-ordering comparator for name sets.
   *
   * Despite its name, the comparison is case-SENSITIVE. It orders names
   * byte-wise as unsigned char (the same order strcmp gives for strings
   * without embedded NULs), and a proper prefix sorts before the longer name.
   *
   * It is length-aware: a jwt::string_view need not be NUL-terminated.
   * Because `is_transparent` is declared, a set keyed on std::string can be
   * searched with a jwt::string_view without building a temporary string.
   */
  struct case_compare
  {
    using is_transparent = std::true_type;

    bool operator()(const std::string& lhs, const std::string& rhs) const
    {
      return less(lhs.data(), lhs.length(), rhs.data(), rhs.length());
    }

    bool operator()(const jwt::string_view lhs, const jwt::string_view rhs) const
    {
      return less(lhs.data(), lhs.length(), rhs.data(), rhs.length());
    }

    bool operator()(const std::string& lhs, const jwt::string_view rhs) const
    {
      return less(lhs.data(), lhs.length(), rhs.data(), rhs.length());
    }

    bool operator()(const jwt::string_view lhs, const std::string& rhs) const
    {
      return less(lhs.data(), lhs.length(), rhs.data(), rhs.length());
    }

  private:
    // Lexicographic "less than" over two (pointer, length) byte ranges.
    // std::char_traits<char>::compare compares as unsigned char, like strcmp,
    // but never reads beyond `common` bytes.
    static bool less(const char* lp, std::size_t llen,
                     const char* rp, std::size_t rlen) noexcept
    {
      const std::size_t common = llen < rlen ? llen : rlen;
      const int ret = std::char_traits<char>::compare(lp, rp, common);
      return ret < 0 || (ret == 0 && llen < rlen);
    }
  };

  /// Ordered set of names (header names / claim names).
  using header_claim_set_t = std::set<std::string, case_compare>;
};

// Forward declarations of the functions that write_interface befriends.
// They are declared here, ahead of the friend declarations, so that the
// default arguments (`pretty=false`) can be specified on these declarations.
// See: https://stackoverflow.com/a/23336823/434233

// Serialize any object exposing create_json_obj() to a JSON string.
// `pretty` presumably selects indented output — confirm in the implementation.
template <typename T, typename = typename std::enable_if<
            detail::meta::has_create_json_obj_member<T>{}>::type>
std::string to_json_str(const T& obj, bool pretty=false);

// Write `obj` to `os`; returns `os`. The constraints on T are not visible
// here (presumably the same create_json_obj() requirement — verify).
template <typename T>
std::ostream& write(std::ostream& os, const T& obj, bool pretty=false);

// Stream insertion, enabled only for types exposing create_json_obj().
template <typename T,
          typename = typename std::enable_if<
                      detail::meta::has_create_json_obj_member<T>{}>::type
         >
std::ostream& operator<< (std::ostream& os, const T& obj);


/**
 * Empty base that grants the serialization free functions declared above
 * (to_json_str, write, operator<<) friend access to classes deriving from it.
 * jwt_header and jwt_payload derive from it.
 */
struct write_interface
{
  // Befriends the to_json_str template forward-declared above
  // (Cond is its enable_if parameter).
  template <typename T, typename Cond>
  friend std::string to_json_str(const T& obj, bool pretty);

  template <typename T>
  friend std::ostream& write(
      std::ostream& os, const T& obj, bool pretty);

  template <typename T, typename Cond>
  friend std::ostream& operator<< (std::ostream& os, const T& obj);
};

/**
 * CRTP mixin that adds base64url helpers to a serializable Derived type.
 *
 * Derived must be accepted by to_json_str (i.e. expose create_json_obj()).
 */
template <typename Derived>
struct base64_enc_dec
{
  /**
   * Serialize the derived object with to_json_str, base64-encode the JSON
   * text, then make it URI safe with jwt::base64_uri_encode.
   *
   * @param with_pretty Forwarded to to_json_str as its `pretty` argument.
   * @return The URI-safe base64 string.
   */
  std::string base64_encode(bool with_pretty = false) const
  {
    const std::string jstr =
        to_json_str(*static_cast<const Derived*>(this), with_pretty);
    std::string b64_str = jwt::base64_encode(jstr.c_str(), jstr.length());

    // base64_uri_encode rewrites the buffer in place and returns the new
    // length, so trim the string to it.
    const auto new_len = jwt::base64_uri_encode(&b64_str[0], b64_str.length());
    b64_str.resize(new_len);

    return b64_str;
  }

  /**
   * Decode a base64url-encoded string via jwt::base64_uri_decode.
   *
   * This does not read or modify the object, so it is const. Existing
   * non-const callers are unaffected.
   *
   * @param encoded_str The encoded input (data and length are passed through).
   * @return The decoded bytes.
   */
  std::string base64_decode(const jwt::string_view encoded_str) const
  {
    return jwt::base64_uri_decode(encoded_str.data(), encoded_str.length());
  }

};


/**
 * JOSE header of a JWT.
 *
 * The header is kept as a JSON object (payload_), together with cached copies
 * of the "alg" and "typ" values (alg_, typ_) and the set of extra header names
 * added through add_header() (headers_).
 */
struct jwt_header: write_interface
                 , base64_enc_dec<jwt_header>
{
public: // 'tors
  /*
   * Default constructor.
   * Produces the JSON {"alg": "none", "typ": "JWT"}.
   */
  jwt_header()
  {
    payload_["alg"] = "none";
    payload_["typ"] = "JWT";
  }

  /*
   * Build a header for the given algorithm and token type.
   */
  jwt_header(SCOPED_ENUM algorithm alg, SCOPED_ENUM type typ = type::JWT)
    : alg_(alg)
    , typ_(typ)
  {
    payload_["typ"] = std::string(type_to_str(typ_));
    payload_["alg"] = std::string(alg_to_str(alg_));
  }

  /*
   * Build a header by decoding an encoded header string (delegates to decode()).
   */
  jwt_header(const jwt::string_view enc_str)
  {
    this->decode(enc_str);
  }

  jwt_header(const jwt_header&) = default;
  jwt_header& operator=(const jwt_header&) = default;

  // The user-declared copy operations suppress the implicit move operations.
  // Restore them so that moves (e.g. jwt_object::header(jwt_header&&)) do not
  // silently deep-copy the JSON object and the name set.
  jwt_header(jwt_header&&) = default;
  jwt_header& operator=(jwt_header&&) = default;

  ~jwt_header() = default;

public: // Exposed APIs
  /// Set the signing algorithm and mirror it into the "alg" header.
  void algo(SCOPED_ENUM algorithm alg)
  {
    alg_ = alg;
    payload_["alg"] = std::string(alg_to_str(alg_));
  }

  /// Set the algorithm from its name and mirror it into the "alg" header.
  /// NOTE(review): sv.data() is passed on without its length, so `sv` is
  /// expected to be NUL-terminated.
  void algo(const jwt::string_view sv)
  {
    alg_ = str_to_alg(sv.data());
    payload_["alg"] = std::string(alg_to_str(alg_));
  }

  SCOPED_ENUM algorithm algo() const noexcept
  {
    return alg_;
  }

  /// Set the token type and mirror it into the "typ" header.
  void typ(SCOPED_ENUM type typ) noexcept
  {
    typ_ = typ;
    payload_["typ"] = std::string(type_to_str(typ_));
  }

  /// Set the token type from its name (see str_to_type) and mirror it into
  /// the "typ" header. The view is forwarded whole, keeping its length.
  void typ(const jwt::string_view sv)
  {
    typ_ = str_to_type(sv);
    payload_["typ"] = std::string(type_to_str(typ_));
  }

  SCOPED_ENUM type typ() const noexcept
  {
    return typ_;
  }

  /**
   * Add an extra header parameter.
   *
   * @param hname     Header name; it need not be NUL-terminated.
   * @param hvalue    Value, forwarded into the JSON header.
   * @param overwrite When false, an already-added header is left unchanged.
   * @return false if the header already exists and overwrite is false,
   *         otherwise true.
   */
  template <typename T,
            typename=std::enable_if_t<
                      !std::is_same<jwt::string_view, std::decay_t<T>>::value
                     >
           >
  bool add_header(const jwt::string_view hname, T&& hvalue, bool overwrite=false)
  {
    if (!overwrite && headers_.find(hname) != std::end(headers_)) {
      return false;
    }

    // Build the key from (data, length): hname need not be NUL-terminated.
    std::string key{hname.data(), hname.length()};

    // Update the JSON first, so a throwing conversion does not leave a name
    // recorded without a matching value.
    payload_[key] = std::forward<T>(hvalue);
    headers_.insert(std::move(key));

    return true;
  }

  /// Overload for string-view values; the value is stored as a std::string.
  bool add_header(const jwt::string_view cname, const jwt::string_view cvalue, bool overwrite=false)
  {
    return add_header(cname,
                      std::string{cvalue.data(), cvalue.length()},
                      overwrite);
  }

  /**
   * Remove a header parameter.
   *
   * "typ" (compared case-insensitively) is special: it resets typ_ to
   * type::NONE and always succeeds. Other names must have been added with
   * add_header().
   *
   * @return true if something was removed (always true for "typ").
   */
  bool remove_header(const jwt::string_view hname)
  {
    const std::string key{hname.data(), hname.length()};

    if (!strcasecmp(key.c_str(), "typ")) {
      typ_ = type::NONE;
      payload_.erase(key);
      return true;
    }

    auto itr = headers_.find(hname);
    if (itr == std::end(headers_)) {
      return false;
    }
    payload_.erase(key);
    headers_.erase(itr);

    return true;
  }

  /// True if the header is present. "typ" (case-insensitive) is reported as
  /// present unless typ_ is type::NONE.
  bool has_header(const jwt::string_view hname) const
  {
    const std::string key{hname.data(), hname.length()};
    if (!strcasecmp(key.c_str(), "typ")) return typ_ != type::NONE;
    return headers_.find(hname) != std::end(headers_);
  }


  //TODO: error code ?
  /// URI-safe base64 encoding of the JSON header (see base64_enc_dec).
  std::string encode(bool pprint = false) const
  {
    return base64_encode(pprint);
  }

  void decode(const jwt::string_view enc_str, std::error_code& ec);

  void decode(const jwt::string_view enc_str);

  /// The underlying JSON object (used by the serialization helpers).
  const json_t& create_json_obj() const
  {
    return payload_;
  }

private: // Data members
  SCOPED_ENUM algorithm alg_ = algorithm::NONE;

  SCOPED_ENUM type      typ_ = type::JWT;

  // The JSON payload object
  json_t payload_;

  //Extra headers for JWS
  jwt_set::header_claim_set_t headers_;
};


/**
 * Claims set (payload) of a JWT.
 *
 * Claims are stored as a JSON object (payload_). The names of claims added
 * through this interface are tracked in claim_names_, which backs the
 * duplicate check in add_claim() and the lookups in has_claim(),
 * remove_claim() and has_claim_with_value().
 */
struct jwt_payload: write_interface
                  , base64_enc_dec<jwt_payload>
{
public: // 'tors
  jwt_payload() = default;

  /*
   * Build a payload by decoding an encoded payload string (delegates to decode()).
   */
  jwt_payload(const jwt::string_view enc_str)
  {
    this->decode(enc_str);
  }

  jwt_payload(const jwt_payload&) = default;
  jwt_payload& operator=(const jwt_payload&) = default;

  // The user-declared copy operations suppress the implicit move operations.
  // Restore them so that moves (e.g. jwt_object::payload(jwt_payload&&)) do
  // not silently deep-copy the JSON object and the name set.
  jwt_payload(jwt_payload&&) = default;
  jwt_payload& operator=(jwt_payload&&) = default;

  ~jwt_payload() = default;

public: // Exposed APIs
  /**
   * Add a claim with a JSON-convertible value.
   *
   * system_time_t and jwt::string_view values are excluded from this
   * template and handled by the dedicated overloads below. The constraint
   * previously used `||`, which is always true.
   *
   * @param cname     Claim name; it need not be NUL-terminated.
   * @param cvalue    Value, forwarded into the JSON payload.
   * @param overwrite When false, duplicate claim names are rejected.
   * @return false if the claim exists and overwrite is false, else true.
   */
  template <typename T,
            typename=std::enable_if_t<
              !std::is_same<system_time_t, std::decay_t<T>>::value &&
              !std::is_same<jwt::string_view, std::decay_t<T>>::value
              >
           >
  bool add_claim(const jwt::string_view cname, T&& cvalue, bool overwrite=false)
  {
    // Duplicate claim names are not allowed unless overwrite is set.
    if (!overwrite && claim_names_.find(cname) != claim_names_.end()) {
      return false;
    }

    // Build the key from (data, length): cname need not be NUL-terminated.
    std::string key{cname.data(), cname.length()};

    // Update the JSON first, so a throwing conversion does not leave a name
    // recorded without a matching value.
    payload_[key] = std::forward<T>(cvalue);
    claim_names_.insert(std::move(key));

    return true;
  }

  /// String-view value: stored as a std::string.
  bool add_claim(const jwt::string_view cname, const jwt::string_view cvalue, bool overwrite=false)
  {
    return add_claim(cname, std::string{cvalue.data(), cvalue.length()}, overwrite);
  }

  /// Time-point value: stored as whole seconds since the clock's epoch.
  bool add_claim(const jwt::string_view cname, system_time_t tp, bool overwrite=false)
  {
    return add_claim(
        cname,
        std::chrono::duration_cast<
          std::chrono::seconds>(tp.time_since_epoch()).count(),
        overwrite
        );
  }

  /// Registered-claim variant of the generic add_claim().
  template <typename T,
            typename=std::enable_if_t<
                      !std::is_same<std::decay_t<T>, system_time_t>::value &&
                      !std::is_same<std::decay_t<T>, jwt::string_view>::value
                     >>
  bool add_claim(SCOPED_ENUM registered_claims cname, T&& cvalue, bool overwrite=false)
  {
    return add_claim(
        reg_claims_to_str(cname),
        std::forward<T>(cvalue),
        overwrite
        );
  }

  /// Registered claim with a time-point value (stored as seconds since epoch).
  bool add_claim(SCOPED_ENUM registered_claims cname, system_time_t tp, bool overwrite=false)
  {
    return add_claim(
        reg_claims_to_str(cname),
        std::chrono::duration_cast<
          std::chrono::seconds>(tp.time_since_epoch()).count(),
        overwrite
        );
  }

  /// Registered claim with a string-view value (stored as std::string).
  bool add_claim(SCOPED_ENUM registered_claims cname, jwt::string_view cvalue, bool overwrite=false)
  {
    return add_claim(
          reg_claims_to_str(cname),
          std::string{cvalue.data(), cvalue.length()},
          overwrite
        );
  }

  /**
   * Return the value of claim @p cname converted to T.
   *
   * Uses json::at, so a missing claim throws (nlohmann json out_of_range,
   * or type_error if the payload is not an object). The const operator[]
   * used previously has undefined behaviour for a missing key.
   */
  template <typename T>
  decltype(auto) get_claim_value(const jwt::string_view cname) const
  {
    return payload_.at(std::string{cname.data(), cname.length()}).get<T>();
  }

  template <typename T>
  decltype(auto) get_claim_value(SCOPED_ENUM registered_claims cname) const
  {
    return get_claim_value<T>(reg_claims_to_str(cname));
  }

  /// Remove a claim that was added via add_claim(); returns false if unknown.
  bool remove_claim(const jwt::string_view cname)
  {
    auto itr = claim_names_.find(cname);
    if (itr == claim_names_.end()) return false;

    // Erase from the JSON using the stored (owning) name before the set
    // entry is destroyed.
    payload_.erase(*itr);
    claim_names_.erase(itr);

    return true;
  }

  bool remove_claim(SCOPED_ENUM registered_claims cname)
  {
    return remove_claim(reg_claims_to_str(cname));
  }

  //TODO: Not all libc++ version agrees with this
  //because count() is not made const for is_transparent
  //based overload
  bool has_claim(const jwt::string_view cname) const noexcept
  {
    return claim_names_.find(cname) != std::end(claim_names_);
  }

  bool has_claim(SCOPED_ENUM registered_claims cname) const noexcept
  {
    return has_claim(reg_claims_to_str(cname));
  }

  /// True if the claim is known and its JSON value compares equal to cvalue.
  template <typename T>
  bool has_claim_with_value(const jwt::string_view cname, T&& cvalue) const
  {
    auto itr = claim_names_.find(cname);
    if (itr == claim_names_.end()) return false;

    return (cvalue == payload_.at(*itr));
  }

  template <typename T>
  bool has_claim_with_value(const SCOPED_ENUM registered_claims cname, T&& value) const
  {
    return has_claim_with_value(reg_claims_to_str(cname), std::forward<T>(value));
  }

  /// URI-safe base64 encoding of the JSON payload (see base64_enc_dec).
  std::string encode(bool pprint = false) const
  {
    return base64_encode(pprint);
  }

  void decode(const jwt::string_view enc_str, std::error_code& ec);

  void decode(const jwt::string_view enc_str);

  /// The underlying JSON object (used by the serialization helpers).
  const json_t& create_json_obj() const
  {
    return payload_;
  }

private:

  json_t payload_;
  jwt_set::header_claim_set_t claim_names_;
};

/**
 * Signature part of a JWT.
 *
 * Holds the key (key_) used by encode()/verify(). Those member functions are
 * defined out of line (presumably in jwt/impl/jwt.ipp, which is included at
 * the end of this header).
 */
struct jwt_signature
{
public: // 'tors
  jwt_signature() = default;

  /// Store a copy of the key (both data and length of the view).
  jwt_signature(const jwt::string_view key)
    : key_(key.data(), key.length())
  {
  }

  jwt_signature(const jwt_signature&) = default;
  jwt_signature& operator=(const jwt_signature&) = default;

  // The user-declared copy operations suppress the implicit move operations.
  // Restore them so that the key string is moved rather than copied.
  jwt_signature(jwt_signature&&) = default;
  jwt_signature& operator=(jwt_signature&&) = default;

  ~jwt_signature() = default;

public: // Exposed APIs
  // Produce the signature for header + payload; failures are presumably
  // reported through `ec` — confirm in the out-of-line definition.
  std::string encode(const jwt_header& header,
                     const jwt_payload& payload,
                     std::error_code& ec);

  // Verify `jwt_sign` against the signed input `hdr_pld_sign` using the
  // algorithm selected from `header`.
  verify_result_t verify(const jwt_header& header,
              const jwt::string_view hdr_pld_sign,
              const jwt::string_view jwt_sign);

private: // Private implementation
  sign_func_t get_sign_algorithm_impl(const jwt_header& hdr) const noexcept;

  verify_func_t get_verify_algorithm_impl(const jwt_header& hdr) const noexcept;

private: // Data members;

  std::string key_;
};


/**
 * A complete JWT: header, payload (claims) and the secret used for signing.
 *
 * Most non-trivial members (the parameter constructor, three_parts,
 * remove_claim, signature, verify, set_parameters, set_decode_params) are
 * only declared here. Their definitions are presumably in jwt/impl/jwt.ipp,
 * which is included at the end of this header.
 */
class jwt_object
{
public: // 'tors
  jwt_object() = default;

  // Construct from one or more parameter objects. Enabled only when the first
  // argument satisfies detail::meta::is_parameter_concept. Presumably the
  // arguments are consumed by the set_parameters overloads — verify in jwt.ipp.
  template <typename First, typename... Rest,
            typename=std::enable_if_t<detail::meta::is_parameter_concept<First>::value>>
  jwt_object(First&& first, Rest&&... rest);

public: // Exposed static APIs
  // Split an encoded token into three parts. Presumably these are the
  // header, payload and signature segments, returned as views into enc_str
  // (so enc_str must outlive the result) — confirm in jwt.ipp.
  static std::array<jwt::string_view, 3>
  three_parts(const jwt::string_view enc_str);

public: // Exposed APIs
  // Mutable access to the payload.
  jwt_payload& payload() noexcept
  {
    return payload_;
  }

  const jwt_payload& payload() const noexcept
  {
    return payload_;
  }

  // Replace the payload (copy).
  void payload(const jwt_payload& p)
  {
    payload_ = p;
  }

  // Replace the payload (move).
  void payload(jwt_payload&& p)
  {
    payload_ = std::move(p);
  }

  // Replace the header (copy).
  void header(const jwt_header& h)
  {
    header_ = h;
  }

  // Replace the header (move).
  void header(jwt_header&& h)
  {
    header_ = std::move(h);
  }

  // Mutable access to the header.
  jwt_header& header() noexcept
  {
    return header_;
  }

  const jwt_header& header() const noexcept
  {
    return header_;
  }

  // Returns a copy of the stored secret.
  std::string secret() const
  {
    return secret_;
  }

  // Store a copy of the secret (data and length of the view).
  void secret(const jwt::string_view sv)
  {
    secret_.assign(sv.data(), sv.length());
  }

  // Add a claim to the payload and return *this for chaining. The result of
  // jwt_payload::add_claim (false for a duplicate name) is discarded.
  // system_time_t values are routed to the non-template overload below.
  template <typename T,
            typename=typename std::enable_if_t<
              !std::is_same<system_time_t, std::decay_t<T>>::value>>
  jwt_object& add_claim(const jwt::string_view name, T&& value)
  {
    payload_.add_claim(name, std::forward<T>(value));
    return *this;
  }

  jwt_object& add_claim(const jwt::string_view name, system_time_t time_point);

  // Registered-claim variant; maps cname to its name via reg_claims_to_str.
  template <typename T>
  jwt_object& add_claim(SCOPED_ENUM registered_claims cname, T&& value)
  {
    return add_claim(reg_claims_to_str(cname), std::forward<T>(value));
  }

  jwt_object& remove_claim(const jwt::string_view name);

  jwt_object& remove_claim(SCOPED_ENUM registered_claims cname)
  {
    return remove_claim(reg_claims_to_str(cname));
  }

  // Forwards to jwt_payload::has_claim.
  bool has_claim(const jwt::string_view cname) const noexcept
  {
    return payload().has_claim(cname);
  }

  bool has_claim(SCOPED_ENUM registered_claims cname) const noexcept
  {
    return payload().has_claim(cname);
  }

  // Produce the signed token string. The error_code overload presumably
  // reports failures via `ec`, and the other presumably throws — confirm in jwt.ipp.
  std::string signature(std::error_code& ec) const;

  std::string signature() const;

  // Validate this object against the decode parameters and the allowed
  // algorithm set; returns an error_code describing the result.
  template <typename Params, typename SequenceT>
  std::error_code verify(
      const Params& dparams,
      const params::detail::algorithms_param<SequenceT>& algos) const;

private: // private APIs
  // Overload set that applies each constructor parameter object (payload,
  // secret, algorithm, headers) to this object; the empty overload
  // presumably terminates the recursion.
  template <typename... Args>
  void set_parameters(Args&&... args);

  template <typename M, typename... Rest>
  void set_parameters(params::detail::payload_param<M>&&, Rest&&...);

  template <typename... Rest>
  void set_parameters(params::detail::secret_param, Rest&&...);

  template <typename... Rest>
  void set_parameters(params::detail::algorithm_param, Rest&&...);

  template <typename M, typename... Rest>
  void set_parameters(params::detail::headers_param<M>&&, Rest&&...);

  void set_parameters();

public: //TODO: Not good
  // Overload set that copies each decode parameter (secret, secret function,
  // leeway, verify flag, issuer, audience, subject, iat/jti validation) into
  // `dparams`. These are public, presumably so the free decode() functions
  // can call them. The single-argument overload presumably ends the recursion.
  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::secret_param s, Rest&&... args);

  template <typename DecodeParams, typename T, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::secret_function_param<T>&& s, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::leeway_param l, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::verify_param v, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::issuer_param i, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::audience_param a, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::subject_param a, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::validate_iat_param v, Rest&&... args);

  template <typename DecodeParams, typename... Rest>
  static void set_decode_params(DecodeParams& dparams, params::detail::validate_jti_param v, Rest&&... args);

  template <typename DecodeParams>
  static void set_decode_params(DecodeParams& dparams);

private: // Data Members

  jwt_header header_;

  jwt_payload payload_;

  std::string secret_;
};

/**
 * Decode an encoded JWT into a jwt_object.
 *
 * Declarations only. The definitions are presumably in jwt/impl/jwt.ipp
 * (included at the end of this header).
 *
 * @param enc_str The encoded token.
 * @param algos   The set of algorithms the caller accepts.
 * @param ec      (first overload) Error output. Presumably this overload
 *                reports failures via `ec` while the second one throws —
 *                confirm against the implementation.
 * @param args    Additional decode parameters. Presumably these are applied
 *                through jwt_object::set_decode_params.
 */
template <typename SequenceT, typename... Args>
jwt_object decode(const jwt::string_view enc_str,
                  const params::detail::algorithms_param<SequenceT>& algos,
                  std::error_code& ec,
                  Args&&... args);

template <typename SequenceT, typename... Args>
jwt_object decode(const jwt::string_view enc_str,
                  const params::detail::algorithms_param<SequenceT>& algos,
                  Args&&... args);


} // END namespace jwt


#include "jwt/impl/jwt.ipp"

#endif