qnet.py
Go to the documentation of this file.
1 """
2 Example of Q-table learning with a simple discretized 1-pendulum environment using a
3 linear Q network.
4 """
5 
6 import signal
7 import time
8 
9 import matplotlib.pyplot as plt
10 import numpy as np
11 import tensorflow as tf
12 from dpendulum import DPendulum
13 
14 
# Seed both NumPy and TensorFlow from the wall clock: the seed is an integer in
# [0, 10000) taken from the current time modulo 10 seconds. It is printed so the
# value used by a given run is visible in the output.
RANDOM_SEED = int((time.time() % 10) * 1000)
print("Seed = %d" % RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)
19 
20 
# Training hyper-parameters.
NEPISODES = 500  # Episode counter bound: the training loop runs episodes 1 .. NEPISODES - 1
NSTEPS = 50  # Step bound per episode: the training loop runs at most NSTEPS - 1 steps
LEARNING_RATE = 0.1  # Step length of the gradient-descent optimizer
DECAY_RATE = 0.99  # Discount factor applied to the next-state Q-value


# Discretized pendulum environment; NX sizes the one-hot state encoding and NU is
# the number of discrete controls (columns of the Q-table).
env = DPendulum()
NX, NU = env.nx, env.nu
30 
31 
32 
class QValueNetwork:
    """Linear Q-value network over one-hot encoded discrete states.

    The network is a single matrix product ``x @ W`` with ``W`` of shape
    ``[NX, NU]``; the training script feeds ``x`` with a one-hot row vector, so
    the output is the row of ``W`` selected by the state.

    Attributes:
        x: Placeholder of shape ``[1, NX]`` for the network input.
        qvalue: Tensor ``x @ W``, the Q-values of every control for input ``x``.
        u: Index of the largest Q-value (greedy control) for input ``x``.
        qref: Placeholder of shape ``[1, NU]`` for the reference Q-values.
        optim: Gradient-descent step minimizing ``sum((qref - qvalue) ** 2)``.
    """

    def __init__(self):
        """Build the graph nodes in the current default TensorFlow graph."""
        x = tf.placeholder(shape=[1, NX], dtype=tf.float32)
        # Small uniform initial weights between 0 and 0.01 with a fixed op seed.
        W = tf.Variable(tf.random_uniform([NX, NU], 0, 0.01, seed=100))
        qvalue = tf.matmul(x, W)
        u = tf.argmax(qvalue, 1)

        qref = tf.placeholder(shape=[1, NU], dtype=tf.float32)
        # Squared error between the reference and the predicted Q-values.
        loss = tf.reduce_sum(tf.square(qref - qvalue))
        optim = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

        self.x = x  # Network input
        self.qvalue = qvalue  # Q-value as a function of x
        self.u = u  # Policy as a function of x
        self.qref = qref  # Reference Q-value at next step (to be set to l+Q o f)
        self.optim = optim  # Optimizer
49 
50 
51 
# Start from an empty default graph, build the Q-network in it, then open an
# interactive session (which installs itself as the default session) and
# initialize the network variables.
tf.reset_default_graph()
qvalue = QValueNetwork()
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
56 
57 
def onehot(ix, n=NX):
    """Return a 1 x n row vector which is 0 everywhere except index <ix> set to 1.

    Args:
        ix: Index of the entry to set to 1. An index outside ``range(n)`` matches
            no entry, so the result is all zeros.
        n: Length of the vector (defaults to the module-level NX).

    Returns:
        A float64 NumPy array of shape ``(1, n)``.
    """
    # Builtin ``float`` replaces the ``np.float`` alias, which was removed in
    # NumPy 1.24; both denote float64.
    return np.array(
        [
            [(i == ix) for i in range(n)],
        ],
        float,
    )
66 
67 
def disturb(u, i):
    """Add integer Gaussian noise to control index u and clip it to [0, NU - 1].

    The noise is a standard normal sample scaled by ``10 / (i / 50 + 10)`` and
    truncated toward zero by ``int``, so its magnitude shrinks as ``i`` grows.

    Args:
        u: Control index to perturb.
        i: Progress counter (the training loop passes the episode number).

    Returns:
        The perturbed index, clipped to the range of valid controls.
    """
    perturbation = np.random.randn() * 10 / (i / 50 + 10)
    u += int(perturbation)
    return np.clip(u, 0, NU - 1)
71 
72 
def rendertrial(maxiter=100):
    """Roll out the greedy policy from a reset environment, rendering each step.

    The roll-out ends after ``maxiter`` steps, or earlier as soon as a step
    returns a reward equal to 1 (in which case "Reward!" is printed).

    Args:
        maxiter: Maximum number of environment steps to run.
    """
    state = env.reset()
    for _ in range(maxiter):
        # Greedy control for the current state, without exploration noise.
        action = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(state)})
        state, reward = env.step(action)
        env.render()
        if reward == 1:
            print("Reward!")
            break
82 
83 
# Roll-out when CTRL-Z is pressed: SIGTSTP triggers a rendered trial instead of
# suspending the process. The handler's signal number and frame are ignored.
signal.signal(signal.SIGTSTP, lambda signum, frame: rendertrial())
87 
88 
h_rwd = []  # Learning history (for plot): summed reward of each episode.


for episode in range(1, NEPISODES):
    x = env.reset()
    rsum = 0.0

    for step in range(NSTEPS - 1):
        # One-hot encoding of the current state, reused for every query below.
        x_enc = onehot(x)

        # Greedy policy, then exploration noise that depends on the episode number.
        u = sess.run(qvalue.u, feed_dict={qvalue.x: x_enc})[0]
        u = disturb(u, episode)
        x2, reward = env.step(u)

        # Reference Q-value at state x: keep the current predictions and replace
        # only the entry of the applied control by reward + discounted best
        # Q-value at the next state.
        Q2 = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x2)})
        Qref = sess.run(qvalue.qvalue, feed_dict={qvalue.x: x_enc})
        Qref[0, u] = reward + DECAY_RATE * np.max(Q2)

        # One gradient step pulling the prediction at x toward the reference.
        sess.run(qvalue.optim, feed_dict={qvalue.x: x_enc, qvalue.qref: Qref})

        x = x2
        rsum += reward
        if reward == 1:
            break

    h_rwd.append(rsum)
    if episode % 20 == 0:
        # Report the summed reward of the last 20 episodes.
        print("Episode #%d done with %d sucess" % (episode, sum(h_rwd[-20:])))
117 
# Average over the episodes actually recorded in h_rwd. The training loop iterates
# range(1, NEPISODES), i.e. NEPISODES - 1 episodes, so dividing by NEPISODES
# would under-report the success rate. max(..., 1) avoids a division by zero when
# no episode was run.
n_recorded = len(h_rwd)
print("Total rate of success: %.3f" % (sum(h_rwd) / max(n_recorded, 1)))
rendertrial()
# Running average of the per-episode reward, one point per recorded episode.
plt.plot(np.cumsum(h_rwd) / np.arange(1, n_recorded + 1))
plt.show()
dpendulum.DPendulum
Definition: dpendulum.py:74
omniidl_be_python_with_docstring.run
def run(tree, args)
Definition: cmake/hpp/idl/omniidl_be_python_with_docstring.py:140
qnet.QValueNetwork.optim
optim
Definition: qnet.py:48
qnet.QValueNetwork.__init__
def __init__(self)
Definition: qnet.py:34
qnet.QValueNetwork.u
u
Definition: qnet.py:46
qnet.QValueNetwork
— Q-value networks
Definition: qnet.py:33
qnet.QValueNetwork.qref
qref
Definition: qnet.py:47
qnet.QValueNetwork.x
x
Definition: qnet.py:44
qnet.rendertrial
def rendertrial(maxiter=100)
Definition: qnet.py:73
qnet.disturb
def disturb(u, i)
Definition: qnet.py:68
qnet.QValueNetwork.qvalue
qvalue
Definition: qnet.py:45
qnet.onehot
def onehot(ix, n=NX)
Definition: qnet.py:58


pinocchio
Author(s):
autogenerated on Wed Dec 25 2024 03:41:18