-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathsnakeGameGATest.py
219 lines (162 loc) · 7.56 KB
/
snakeGameGATest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#**************************************************************************************
#snakeGameGATest.py
#Author: Craig Haber
#5/9/2020
#Module with the SnakeGameGATest class that is instantiated in testTrainedAgents.py
#to observe the best agents that were trained with the genetic algorithm.
#*************************************************************************************
import pygame
import random
import collections
from helpers.snakeGame import SnakeGame
from helpers.snake import Snake
from helpers import neuralNetwork as nn
class SnakeGameGATest(SnakeGame):
    """Class framework to observe agents that were trained with the genetic algorithm to play the Snake Game.

    Inherits the SnakeGame class that runs the Snake Game and overrides its movement,
    fruit-collision and game-over handling so that a neural network steers the snake.

    Attributes:
        self.frames_since_last_fruit: The number of frames since the last fruit was eaten by the snake.
        self.bits_per_weight: The number of bits per each weight in the neural network.
        self.num_inputs: The number of inputs in the neural network.
        self.num_hidden_layer_nodes: The number of nodes per each of the 2 hidden layers in the neural network.
        self.num_outputs: The number of outputs in the neural network.
        self.weights: The weights for the neural network converted from the chromosome bit sequence of the agent.
            Only set by __init__ when a non-empty chromosome is given.
    """

    # Divisor applied to each open-space count before it is fed to the network.
    OPEN_SPACE_SCALE = 20
    # A snake scoring fewer than LOW_SCORE_THRESHOLD points is ended after
    # LOW_SCORE_STARVE_FRAMES frames without fruit; any snake is ended after MAX_STARVE_FRAMES.
    LOW_SCORE_THRESHOLD = 6
    LOW_SCORE_STARVE_FRAMES = 50
    MAX_STARVE_FRAMES = 250

    def __init__(self, fps, chromosome, bits_per_weight, num_inputs, num_hidden_layer_nodes, num_outputs):
        """Initializes the SnakeGameGATest class.

        The only argument that is not a documented class attribute is:
            chromosome: A string of bits representing all of the weights for the neural network.
                An empty string means the weights are not decoded here (see below).
        """
        super().__init__(fps)
        self.frames_since_last_fruit = 0
        self.bits_per_weight = bits_per_weight
        self.num_inputs = num_inputs
        self.num_hidden_layer_nodes = num_hidden_layer_nodes
        self.num_outputs = num_outputs
        # chromosome will be an empty string if this class was inherited by the class SnakeGameGATrain.
        # This is because there will be a population of chromosomes, and not just one chromosome to test.
        if chromosome != "":
            self.weights = nn.mapChrom2Weights(chromosome, self.bits_per_weight, self.num_inputs,
                                               self.num_hidden_layer_nodes, self.num_outputs)

    def move_snake(self):
        """Determines where the snake should move next based on the neural network.

        This overrides the method in the SnakeGame superclass. The network receives the
        Manhattan distance to the fruit and the (scaled) reachable open space for each of
        the four neighbouring squares of the head, plus the snake's length. The direction
        with the highest network output is taken.
        """
        head = self.snake.body[0]
        row, col = head[0], head[1]
        # (row, column) of the square the head would move to in each direction
        left = (row, col - 1)
        up = (row - 1, col)
        right = (row, col + 1)
        down = (row + 1, col)

        # Manhattan distance of the fruit from the head if it moves in each direction
        dist_left_fruit = self.manhattan_distance(*left)
        dist_up_fruit = self.manhattan_distance(*up)
        dist_right_fruit = self.manhattan_distance(*right)
        dist_down_fruit = self.manhattan_distance(*down)

        # Space available for turning in each of the four directions, reduced by a constant factor
        scale = self.OPEN_SPACE_SCALE
        open_spaces_left = self.calc_open_spaces(left) / scale
        open_spaces_up = self.calc_open_spaces(up) / scale
        open_spaces_right = self.calc_open_spaces(right) / scale
        open_spaces_down = self.calc_open_spaces(down) / scale

        # Length of the snake
        length = self.score + 1

        # NOTE: the open-space inputs are ordered left, up, DOWN, RIGHT (unlike the distance
        # inputs). This order is part of the network's input contract -- presumably the order
        # used when the weights were trained -- so do not reorder it without retraining.
        network_inputs = [dist_left_fruit, dist_up_fruit, dist_right_fruit, dist_down_fruit,
                          open_spaces_left, open_spaces_up, open_spaces_down, open_spaces_right,
                          length]

        # One "goodness" value per direction from the neural network
        outputs = nn.testNetwork(network_inputs, self.weights, self.num_hidden_layer_nodes, self.num_outputs)
        max_output = max(outputs)
        # Outputs 0-2 score left/up/right; anything else falls back to "down".
        # The first direction whose output equals the maximum wins ties.
        direct = next(
            (name for name, value in zip(("left", "up", "right"), outputs) if value == max_output),
            "down",
        )

        self.snake.directions.appendleft(direct)
        # Trim the direction history so it is no longer than the body
        if len(self.snake.directions) > len(self.snake.body):
            self.snake.directions.pop()
        self.snake.update_body_positions()

    def manhattan_distance(self, y_head, x_head):
        """Calculates the Manhattan distance between the fruit and the snake's head.

        Arguments:
            y_head: The row in the grid of the snake's head.
            x_head: The column in the grid of the snake's head.
        Returns:
            The Manhattan distance between the fruit and the snake's head.
        """
        return abs(self.fruit_pos[0] - y_head) + abs(self.fruit_pos[1] - x_head)

    def _in_bounds(self, pos):
        """Returns True if the (row, column) position pos lies inside the grid."""
        return 0 <= pos[0] < self.rows and 0 <= pos[1] < self.cols

    def calc_open_spaces(self, start_pos):
        """Calculates the number of open spaces reachable from a position.

        An open space is a square inside the grid that is not part of the snake's body.
        A breadth-first search from start_pos counts every open square reachable by moving
        between adjacent open squares; start_pos itself is not counted.

        Arguments:
            start_pos: A tuple in (row, column) format representing a position next to the snake's head.
        Returns:
            An integer of how many open spaces are reachable; 0 if start_pos is out of
            bounds or inside the snake's body.
        """
        body = self.snake.body
        if start_pos in body or not self._in_bounds(start_pos):
            return 0

        open_spaces = 0
        visited = {start_pos}
        queue = collections.deque([start_pos])
        while queue:
            cur = queue.popleft()
            for move in self.get_possible_moves(cur):
                if move in visited:
                    continue
                visited.add(move)
                # Only open squares are counted and expanded further
                if move not in body:
                    open_spaces += 1
                    queue.append(move)
        return open_spaces

    def get_possible_moves(self, cur):
        """Gets all the in-bounds adjacent positions of a position.

        Called from calc_open_spaces() during the breadth-first search.

        Arguments:
            cur: A tuple in (row, column) format representing the position
                to get the next possible moves from.
        Returns:
            A list of (row, column) tuples for the adjacent squares (left, up, right, down
            order) that lie inside the grid.
        """
        row, col = cur[0], cur[1]
        adjacent_spaces = [(row, col - 1), (row - 1, col), (row, col + 1), (row + 1, col)]
        return [move for move in adjacent_spaces if self._in_bounds(move)]

    def check_fruit_collision(self):
        """Detects and handles the snake colliding with a fruit.

        This overrides the method in the SnakeGame superclass.
        """
        if self.snake.body[0] == self.fruit_pos:
            # Add the new body square to the tail of the snake
            self.snake.extend_snake()
            # Generate a new fruit in a random position
            self.generate_fruit()
            self.score += 1
            self.frames_since_last_fruit = 0

    def update_frames_since_last_fruit(self):
        """Advances the frame counter and ends the game if the snake has gone too long without fruit.

        A snake scoring fewer than LOW_SCORE_THRESHOLD points is ended after
        LOW_SCORE_STARVE_FRAMES frames without fruit; any snake is ended after MAX_STARVE_FRAMES.
        """
        self.frames_since_last_fruit += 1
        frames = self.frames_since_last_fruit
        # The counter grows by exactly one per call, so each equality fires once per threshold.
        starving_low_scorer = frames == self.LOW_SCORE_STARVE_FRAMES and self.score < self.LOW_SCORE_THRESHOLD
        if starving_low_scorer or frames == self.MAX_STARVE_FRAMES:
            self.game_over()

    def game_over(self):
        """Restarts the game upon game over.

        Creates a new snake and fruit, flags a restart, records a new high score if one was
        reached, and resets the score and frame counter. This overrides the method in the
        SnakeGame superclass.
        """
        self.snake = Snake(self.rows, self.cols)
        self.generate_fruit()
        self.restart = True
        if self.score > self.high_score:
            self.high_score = self.score
        self.score = 0
        self.frames_since_last_fruit = 0