
More Q learning... it does not seem effective

rparedis committed 4 years ago
commit a484368789
3 changed files with 71 additions and 108 deletions
  1. examples/AGV/AGVEnv.py (+30 -19)
  2. examples/AGV/actions.csv (+41 -89)
  3. examples/AGV/data.npz (BIN)

+ 30 - 19
examples/AGV/AGVEnv.py

@@ -58,11 +58,11 @@ class AGVEnv(gym.Env):
 		if action > 0:
 			self.same_actions = 0
 		self.last_action = r, d
-		if abs(r) < EPS or abs(d) < EPS:
+		if abs(r) < EPS or d < 0.0:
 			return self.states[-1], float('-inf'), True, {}
 		# ro, do = self.last_action
 		# reward = -np.power(ro - r, 2) - np.power(do - d, 2)
-		reward = self.same_actions * 100
+		reward = (self.same_actions ** 3) * 10
 		agv = AGVVirtual("AGV", r, d, "obtained.csv", initial=self.states[-1], v=0.033, T=35, Kp=-0.01)
 		sim = Simulator(agv)
 		sim.setDeltaT(0.2)
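This first hunk tightens the failure condition (a near-zero turn radius still aborts the episode, but the distance now only aborts when negative, where previously a near-zero distance aborted too) and makes the repeat-action bonus cubic instead of linear. A minimal sketch of the reshaped logic, using a hypothetical step_reward helper; EPS is assumed to be a small tolerance defined elsewhere in AGVEnv.py:

    EPS = 1e-6  # assumed tolerance; the real constant lives elsewhere in AGVEnv.py

    def step_reward(r, d, same_actions):
        # Hypothetical helper mirroring the reshaped logic in AGVEnv.step.
        # A near-zero radius is still invalid, but a small positive distance
        # is now allowed: only a negative distance aborts.
        if abs(r) < EPS or d < 0.0:
            return float('-inf'), True  # worst reward, episode done
        # The repeat bonus grows cubically (1 repeat -> 10, 5 repeats -> 1250)
        # instead of the old linear 100 per repeat.
        return (same_actions ** 3) * 10, False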
@@ -74,11 +74,11 @@ class AGVEnv(gym.Env):
 
 		moment = self.physical[self.physical["time"] <= self.time].iloc[-1]
 		offset = self.euclidean(moment["x"], moment["y"], state[0], state[1])
-		if offset > 0.1:
-			reward -= 1000
-		else:
-			reward -= offset
-			reward += self.euclidean(state[0], state[1], last_state[0], last_state[1]) ** 2
+		# if offset > 0.1:
+		# 	reward -= 10000
+		# else:
+		reward -= (offset * 100) ** 2
+		reward += self.euclidean(state[0], state[1], last_state[0], last_state[1])
 
 		TCP = agv.getBlockByName("TCP")
 		end_time = TCP.data[TCP.time_col][-1]
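The second hunk smooths the tracking penalty: the old hard cliff (-1000 whenever the offset from the recorded physical trajectory exceeded 0.1) becomes a quadratic -(offset * 100)**2, which is -100 at offset 0.1 and only reaches -10000, the value in the commented-out cliff, at offset 1.0. The progress bonus between consecutive states is likewise reduced from squared to linear. A sketch of the new terms, with euclidean assumed to be the class's plain 2-D distance and hypothetical coordinate tuples standing in for the pandas row and state arrays:

    import math

    def euclidean(x1, y1, x2, y2):
        # Assumed to match AGVEnv.euclidean: plain 2-D Euclidean distance.
        return math.hypot(x1 - x2, y1 - y2)

    def tracking_terms(moment_xy, state_xy, last_state_xy):
        # Smooth quadratic penalty for straying from the reference trajectory.
        offset = euclidean(*moment_xy, *state_xy)
        reward = -(offset * 100) ** 2
        # Linear (previously squared) bonus for distance covered this step.
        reward += euclidean(*state_xy, *last_state_xy)
        return reward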
@@ -139,13 +139,20 @@ class AGVEnv(gym.Env):
 
 
 if __name__ == '__main__':
-	import random
+	import random, os
 	env = AGVEnv()
 
 	action_space_size = env.action_space.n
 	state_space_size = 100 * 100 * (6 * 360)
 
-	q_table = np.zeros((state_space_size, action_space_size))
+	if os.path.isfile("data.npz"):
+		file = np.load("data.npz")
+		q_table = file['Q']
+		exploration_rate = file['rate'][0]
+		print("READ FROM FILE")
+	else:
+		q_table = np.zeros((state_space_size, action_space_size))
+		exploration_rate = 1
 
 	num_episodes = 1000
 	max_steps_per_episode = 100 # but it won't go higher than 1
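The third hunk makes training resumable: the Q-table and exploration rate are restored from data.npz when a previous run left one behind, instead of always starting from zeros with full exploration. A round-trip sketch of that checkpointing, using hypothetical load/save helpers whose keys mirror the np.savez_compressed call at the end of the script:

    import os
    import numpy as np

    def load_checkpoint(path, q_shape):
        # Resume from an earlier run if a checkpoint exists; otherwise start
        # from scratch with a zero Q-table and full exploration (rate 1.0).
        if os.path.isfile(path):
            data = np.load(path)
            return data['Q'], float(data['rate'][0])
        return np.zeros(q_shape), 1.0

    def save_checkpoint(path, q_table, actions, exploration_rate):
        # Mirrors np.savez_compressed(..., Q=..., actions=..., rate=...).
        np.savez_compressed(path, Q=q_table, actions=np.asarray(actions),
                            rate=np.array([exploration_rate]))

With state_space_size = 100 * 100 * (6 * 360) = 21,600,000 rows, the dense table is large in memory, but while it is mostly zeros the compressed .npz stays small on disk.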
@@ -153,7 +160,7 @@ if __name__ == '__main__':
 	learning_rate = 0.1
 	discount_rate = 0.99
 
-	exploration_rate = 1
+	# exploration_rate = 1
 	max_exploration_rate = 1
 	min_exploration_rate = 0.01
 
@@ -168,7 +175,7 @@ if __name__ == '__main__':
 		for episode in range(num_episodes):
 			state = env.reset()
 			dstate = discretize(state)
-			env.label.set_text("Episode: " + str(episode))
+			env.label.set_text(f"Episode: {episode}\nR: {exploration_rate:.4f}")
 
 			done = False
 			rewards_current_episode = 0
@@ -176,7 +183,7 @@ if __name__ == '__main__':
 			for step in range(max_steps_per_episode):
 				env.render()
 
-				# Exploration -exploitation trade-off
+				# Exploration-exploitation trade-off
 				exploration_rate_threshold = random.uniform(0, 1)
 				if exploration_rate_threshold > exploration_rate:
 					action = np.argmax(q_table[dstate,:])
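The fixed comment labels standard epsilon-greedy selection: exploit the greedy action when a uniform draw exceeds the exploration rate, otherwise explore. The explore branch and the Q-update sit outside this hunk, so the sketch below reconstructs them under common assumptions: random actions come from env.action_space.sample(), and the update is the standard tabular rule with the learning_rate (0.1) and discount_rate (0.99) set above.

    import random
    import numpy as np

    def choose_action(q_table, dstate, exploration_rate, action_space):
        # Epsilon-greedy: exploit with probability 1 - exploration_rate.
        if random.uniform(0, 1) > exploration_rate:
            return int(np.argmax(q_table[dstate, :]))
        return action_space.sample()  # assumed explore branch (not shown in the hunk)

    def q_update(q_table, s, a, reward, s_next, lr=0.1, gamma=0.99):
        # Standard tabular Q-learning update; the loop's actual update is
        # elided from this diff, so this is an assumed reconstruction.
        q_table[s, a] = (1 - lr) * q_table[s, a] + \
                        lr * (reward + gamma * np.max(q_table[s_next, :]))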
@@ -201,8 +208,8 @@ if __name__ == '__main__':
 			                   (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)
 
 			rewards_all_episodes.append(rewards_current_episode)
-	except:
-		print("ERROR!")
+	except BaseException as e:
+		print("ERROR!", e)
 
 	# Calculate and print the average reward per 10 episodes
 	rewards_per_thousand_episodes = np.split(np.array(rewards_all_episodes), num_episodes / 100)
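The exploration rate follows the usual exponential schedule, decaying from max_exploration_rate toward min_exploration_rate. Worth noting: the context line above recomputes the rate from the episode index alone, so a rate restored from data.npz only influences episode 0 before being overwritten, which partly undercuts the checkpointing added earlier in this commit. (The reporting block that follows also averages per 100 episodes, despite the older "per 10" comment.) The schedule as a standalone sketch; the decay constant is defined outside the hunk, so 0.001 here is a placeholder:

    import numpy as np

    def decayed_rate(episode, min_rate=0.01, max_rate=1.0, decay=0.001):
        # Exponential epsilon decay: starts near max_rate and approaches
        # min_rate asymptotically as episodes accumulate.
        return min_rate + (max_rate - min_rate) * np.exp(-decay * episode)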
@@ -213,11 +220,15 @@ if __name__ == '__main__':
 		print(count, ": ", str(sum(r / 100)))
 		count += 100
 
+	print("\n\n********** Vars **********\n")
+	print("Exploration Rate:", exploration_rate)
+
 	# Print updated Q-table
 	# print("\n\n********** Q-table **********\n")
 	# print(q_table)
-	np.save("Q.npy", q_table)
-	with open("actions.csv", 'w') as file:
-		file.write(f"r,d")
-		for r, d in env.actions:
-			file.write(f"{r:.3f},{d:.3f}\n")
+	np.savez_compressed("data.npz", Q=q_table, actions=np.asarray(env.actions), rate=np.array([exploration_rate]))
+	# np.save("Q.npy", q_table)
+	# with open("actions.csv", 'w') as file:
+	# 	file.write(f"r,d\n")
+	# 	for r, d in env.actions:
+	# 		file.write(f"{r:.3f},{d:.3f}\n")

+ 41 - 89
examples/AGV/actions.csv

@@ -1,89 +1,41 @@
-r,d0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
-0.018,0.211
+r,d
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.211
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.221
+0.018,0.231
+0.018,0.231
+0.018,0.241
+0.018,0.241
+0.018,0.251
+0.018,0.251
+0.018,0.261
+0.018,0.261
+0.018,0.271
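The rewritten header also explains the malformed first line of the old file: the previous writer emitted the header as f"r,d" with no trailing newline, so it fused with the first action row into r,d0.018,0.211. The replacement writer appends the missing \n, though it is now commented out in AGVEnv.py in favor of bundling the actions into data.npz.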

BIN
examples/AGV/data.npz
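Because data.npz is binary, the diff cannot render it; its contents are exactly the three arrays written by the np.savez_compressed call above. A quick way to inspect the new checkpoint:

    import numpy as np

    data = np.load("examples/AGV/data.npz")
    # Keys match the save call in AGVEnv.py: 'Q', 'actions', and 'rate'.
    print(data['Q'].shape, data['actions'].shape, float(data['rate'][0]))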