qnet.py
1 """
2 Example of Q-table learning with a simple discretized 1-pendulum environment using a
3 linear Q network.
4 """
5 
6 import signal
7 import time
8 
9 import matplotlib.pyplot as plt
10 import numpy as np
11 import tensorflow as tf
12 from dpendulum import DPendulum
13 
14 # --- Random seed
15 RANDOM_SEED = int((time.time() % 10) * 1000)
16 print(f"Seed = {RANDOM_SEED}")
17 np.random.seed(RANDOM_SEED)
18 tf.set_random_seed(RANDOM_SEED)
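# Seeding both NumPy and TensorFlow from the wall clock makes every run different;
# printing the seed lets a particular run be reproduced by hard-coding it.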

# --- Hyper parameters
NEPISODES = 500  # Number of training episodes
NSTEPS = 50  # Max episode length
LEARNING_RATE = 0.1  # Step length in optimizer
DECAY_RATE = 0.99  # Discount factor

# --- Environment
env = DPendulum()
NX = env.nx
NU = env.nu


# --- Q-value networks
class QValueNetwork:
    def __init__(self):
        x = tf.placeholder(shape=[1, NX], dtype=tf.float32)
        W = tf.Variable(tf.random_uniform([NX, NU], 0, 0.01, seed=100))
        qvalue = tf.matmul(x, W)
        u = tf.argmax(qvalue, 1)

        qref = tf.placeholder(shape=[1, NU], dtype=tf.float32)
        loss = tf.reduce_sum(tf.square(qref - qvalue))
        optim = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

        self.x = x  # Network input
        self.qvalue = qvalue  # Q-value as a function of x
        self.u = u  # Policy as a function of x
        # Reference Q-value at next step (to be set to l+Q o f)
        self.qref = qref
        self.optim = optim  # Optimizer
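# With a one-hot encoding of the discretized state as input, qvalue = x @ W just
# selects one row of W, so this linear "network" is exactly an NX x NU Q-table,
# and the gradient step on the squared TD error amounts to tabular Q-learning.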


# --- TensorFlow initialization
tf.reset_default_graph()
qvalue = QValueNetwork()
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
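# InteractiveSession installs itself as the default session, which is why the
# initializer op can be .run() directly here without an explicit session argument.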


def onehot(ix, n=NX):
    """Return a vector which is 0 everywhere except index <ix>, which is set to 1."""
    return np.array(
        [
            [(i == ix) for i in range(n)],
        ],
        np.float64,
    )
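# e.g. onehot(2, n=4) returns array([[0., 0., 1., 0.]]), with shape (1, n) to
# match the [1, NX] placeholder of the network.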


def disturb(u, i):
    u += int(np.random.randn() * 10 / (i / 50 + 10))
    return np.clip(u, 0, NU - 1)
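# The exploration noise has standard deviation 10 / (i / 50 + 10): roughly one
# control step at episode 0, shrinking as training progresses, so the policy
# becomes increasingly greedy.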


def rendertrial(maxiter=100):
    x = env.reset()
    for i in range(maxiter):
        u = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(x)})
        x, r = env.step(u)
        env.render()
        if r == 1:
            print("Reward!")
            break


signal.signal(
    signal.SIGTSTP, lambda x, y: rendertrial()
)  # Roll-out when CTRL-Z is pressed

# --- History of search
h_rwd = []  # Learning history (for plot).

# --- Training
for episode in range(1, NEPISODES):
    x = env.reset()
    rsum = 0.0

    for step in range(NSTEPS - 1):
        # Greedy policy ...
        u = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(x)})[0]
        u = disturb(u, episode)  # ... with noise
        x2, reward = env.step(u)

        # Compute reference Q-value at state x respecting HJB
        Q2 = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x2)})
        Qref = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x)})
        Qref[0, u] = reward + DECAY_RATE * np.max(Q2)
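        # Only the entry for the control actually taken is moved toward the
        # Bellman target reward + DECAY_RATE * max(Q2); the other entries equal
        # the current prediction, so they contribute no gradient.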

        # Update Q-table to better fit HJB
        sess.run(qvalue.optim, feed_dict={qvalue.x: onehot(x), qvalue.qref: Qref})

        rsum += reward
        x = x2
        if reward == 1:
            break

    h_rwd.append(rsum)
    if not episode % 20:
        print(f"Episode #{episode} done with {sum(h_rwd[-20:])} successes")

print(f"Total rate of success: {sum(h_rwd) / NEPISODES:.3f}")
rendertrial()
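# Learning curve: running average of the per-episode reward, i.e. the success
# rate as a function of the number of episodes.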
plt.plot(np.cumsum(h_rwd) / range(1, NEPISODES))
plt.show()