buffer.cpp
Go to the documentation of this file.
1 /*********************************************************************
2  * BSD 3-Clause License
3  *
4  * Copyright (c) 2020 Northwestern University
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * * Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  *
13  * * Redistributions in binary form must reproduce the above copyright notice,
14  * this list of conditions and the following disclaimer in the documentation
15  * and/or other materials provided with the distribution.
16  *
17  * * Neither the name of the copyright holder nor the names of its
18  * contributors may be used to endorse or promote products derived from
19  * this software without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31  *********************************************************************/
39 #include <iostream>
40 
42 
43 namespace ergodic_exploration
44 {
45 using arma::distr_param;
46 using arma::ivec;
47 using arma::randi;
48 
49 ReplayBuffer::ReplayBuffer(unsigned int buffer_size, unsigned int batch_size)
50  : buffer_size_(buffer_size), batch_size_(batch_size)
51 {
52 }
53 
54 void ReplayBuffer::append(const vec& x)
55 {
56  if (memory_.size() < buffer_size_)
57  {
58  memory_.emplace(memory_.size(), x);
59  return;
60  }
61  std::cout << "WARNING: Buffer is full" << std::endl;
62 }
63 
64 mat ReplayBuffer::sampleMemory(const mat& xt) const
65 {
66  if (memory_.empty())
67  {
68  return xt;
69  }
70 
71  // total states
72  mat xt_total;
73 
74  // Concatenate the current store states with predicted trajectory
75  if (memory_.size() <= batch_size_)
76  {
77  const auto num_stored = memory_.size();
78  const auto num_states = xt.n_cols + num_stored;
79  xt_total.resize(xt.n_rows, num_states);
80 
81  for (unsigned int i = 0; i < num_stored; i++)
82  {
83  // Index is the key
84  xt_total.col(i) = memory_.at(i);
85  }
86 
87  // Copy predicted trajectory to end
88  xt_total.cols(num_stored, num_states - 1) = xt;
89  }
90 
91  // Randomly sample memory and concatenate with predicted trajectory
92  else
93  {
94  const auto num_states = xt.n_cols + batch_size_;
95  xt_total.resize(xt.n_rows, num_states);
96 
97  // random ints on interval [a b]
98  const ivec rand_ints = randi<ivec>(batch_size_, distr_param(0, memory_.size() - 1));
99 
100  for (unsigned int i = 0; i < batch_size_; i++)
101  {
102  // Index is the key
103  xt_total.col(i) = memory_.at(rand_ints(i));
104  }
105 
106  // Copy predicted trajectory to end
107  xt_total.cols(batch_size_, num_states - 1) = xt;
108  }
109 
110  return xt_total;
111 }
112 } // namespace ergodic_exploration
ergodic_exploration::ReplayBuffer::sampleMemory
mat sampleMemory(const mat &xt) const
Sample states from memory.
Definition: buffer.cpp:64
ergodic_exploration::ReplayBuffer::append
void append(const vec &x)
Add current state to memory.
Definition: buffer.cpp:54
ergodic_exploration::ReplayBuffer::memory_
std::unordered_map< unsigned int, vec > memory_
Definition: buffer.hpp:76
ergodic_exploration::ReplayBuffer::ReplayBuffer
ReplayBuffer(unsigned int buffer_size, unsigned int batch_size)
Constructor.
Definition: buffer.cpp:49
ergodic_exploration::ReplayBuffer::batch_size_
unsigned int batch_size_
Definition: buffer.hpp:75
ergodic_exploration
Definition: basis.hpp:43
ergodic_exploration::ReplayBuffer::buffer_size_
unsigned int buffer_size_
Definition: buffer.hpp:74
buffer.hpp
Stores past states.


ergodic_exploration
Author(s): bostoncleek
autogenerated on Wed Mar 2 2022 00:17:13