qnet.py
Go to the documentation of this file.
1 '''
2 Example of Q-table learning with a simple discretized 1-pendulum environment using a linear Q network.
3 '''
4 
5 import numpy as np
6 import random
7 import tensorflow as tf
8 import matplotlib.pyplot as plt
9 from dpendulum import DPendulum
10 import signal
11 import time
12 
13 
# Derive a seed from the wall clock (an integer in [0, 10000)) and echo it so
# that the run can be identified from its log.
RANDOM_SEED = int(1000 * (time.time() % 10))
print("Seed = %d" % RANDOM_SEED)

# Seed both NumPy (exploration noise) and TensorFlow with the same value.
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)
18 
19 
# --- Hyper-parameters of the training run
NEPISODES = 500      # Number of training episodes
NSTEPS = 50          # Max episode length
LEARNING_RATE = 0.1  # Step length in optimizer
DECAY_RATE = 0.99    # Discount factor

# --- Environment
env = DPendulum()
# NX is the width of the one-hot state vector fed to the network,
# NU the number of Q-values (one per discrete control) it outputs.
NX, NU = env.nx, env.nu
29 
30 
class QValueNetwork:
    '''
    Q-value networks.

    A single linear layer: the input is a 1 x NX row vector (the script feeds it
    a one-hot state encoding) and the output is a 1 x NU row of Q-values, one per
    discrete control. Training minimises the squared error between that output
    and a reference row supplied through <qref>.
    '''

    def __init__(self):
        '''Build the graph nodes in the current default TensorFlow graph.'''
        x = tf.placeholder(shape=[1, NX], dtype=tf.float32)
        # Weights of the linear map, drawn uniformly in [0, 0.01).
        W = tf.Variable(tf.random_uniform([NX, NU], 0, 0.01, seed=100))
        qvalue = tf.matmul(x, W)
        # Greedy control: index of the largest Q-value in the row.
        u = tf.argmax(qvalue, 1)

        qref = tf.placeholder(shape=[1, NU], dtype=tf.float32)
        loss = tf.reduce_sum(tf.square(qref - qvalue))
        optim = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

        self.x = x            # Network input
        self.qvalue = qvalue  # Q-value as a function of x
        self.u = u            # Policy as a function of x
        self.qref = qref      # Reference Q-value at next step (to be set to l+Q o f)
        self.optim = optim    # Optimizer
47 
48 
# Start from an empty default graph before building the network.
tf.reset_default_graph()
# tf.set_random_seed() applies to the current default graph only, so a seed set
# before the reset above is discarded with the old graph; apply it again here so
# that the graph the network is built in actually carries RANDOM_SEED.
tf.set_random_seed(RANDOM_SEED)
qvalue = QValueNetwork()
# InteractiveSession installs itself as the default session, which is what
# allows the bare .run() call on the initializer below.
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
53 
def onehot(ix, n=NX):
    '''
    Return a 1 x n float array which is 0 everywhere except index <ix> set to 1.
    An index outside range(n) yields an all-zero row.
    '''
    # Use the builtin float (float64) as dtype: the np.float alias was deprecated
    # and later removed from NumPy, where it raises AttributeError.
    return np.array([[i == ix for i in range(n)]], dtype=float)
57 
def disturb(u, i):
    '''
    Perturb the control index <u> with random integer noise and clip it to [0, NU-1].

    The noise is a standard Gaussian sample scaled by 10/(i/50+10) and truncated
    to an integer, so its amplitude shrinks as <i> grows.
    '''
    noise = int(np.random.randn() * 10 / (i / 50 + 10))
    u += noise
    return np.clip(u, 0, NU - 1)
61 
def rendertrial(maxiter=100):
    '''
    Roll out the current greedy policy from a reset state, rendering every step.

    Stops after <maxiter> steps, or earlier as soon as a step returns a reward of 1.
    '''
    x = env.reset()
    for _ in range(maxiter):
        greedy = {qvalue.x: onehot(x)}
        u = sess.run(qvalue.u, feed_dict=greedy)
        x, r = env.step(u)
        env.render()
        if r == 1:
            print('Reward!')
            break
def _rollout_on_signal(signum, frame):
    '''Signal handler: display one roll-out of the current policy.'''
    return rendertrial()


# Roll-out when CTRL-Z is pressed
signal.signal(signal.SIGTSTP, _rollout_on_signal)


# Learning history (for plot): one accumulated reward per episode.
h_rwd = []
73 
74 
for episode in range(1, NEPISODES):
    x = env.reset()
    rsum = 0.0

    for step in range(NSTEPS - 1):
        # One-hot encoding of the current state, reused for every query below.
        x_onehot = onehot(x)

        # Greedy policy ...
        u = sess.run(qvalue.u, feed_dict={qvalue.x: x_onehot})[0]
        # ... with noise
        u = disturb(u, episode)
        x2, reward = env.step(u)

        # Compute reference Q-value at state x respecting HJB: start from the
        # current prediction at x and overwrite only the entry of the control
        # actually applied with reward + discounted best Q-value at x2.
        Qref = sess.run(qvalue.qvalue, feed_dict={qvalue.x: x_onehot})
        Q2 = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x2)})
        Qref[0, u] = reward + DECAY_RATE * Q2.max()

        # Update Q-table to better fit HJB
        sess.run(qvalue.optim, feed_dict={qvalue.x: x_onehot, qvalue.qref: Qref})

        rsum += reward
        x = x2
        # A reward of 1 ends the episode.
        if reward == 1:
            break

    h_rwd.append(rsum)
    # Progress report every 20 episodes, summing the last 20 episode rewards.
    if episode % 20 == 0:
        print('Episode #%d done with %d sucess' % (episode, sum(h_rwd[-20:])))
99 
# h_rwd holds one entry per episode actually run. The training loop iterates over
# range(1, NEPISODES), i.e. NEPISODES-1 episodes, so the rate must be normalised
# by len(h_rwd) rather than by NEPISODES. max(..., 1) guards the empty-history case.
nb_episodes_run = len(h_rwd)
print("Total rate of success: %.3f" % (sum(h_rwd) / max(nb_episodes_run, 1)))
rendertrial()
# Running average of the per-episode reward, using the same episode count.
plt.plot(np.cumsum(h_rwd) / np.arange(1, nb_episodes_run + 1))
plt.show()
104 
def disturb(u, i)
Definition: qnet.py:58
def rendertrial(maxiter=100)
Definition: qnet.py:62
def __init__(self)
Definition: qnet.py:32
def onehot(ix, n=NX)
Definition: qnet.py:54
QValueNetwork — Q-value networks
Definition: qnet.py:31


pinocchio
Author(s):
autogenerated on Fri Jun 23 2023 02:38:32