Skip to content

Commit

Permalink
Merge pull request #94 from UoA-CARES/76-carbeat
Browse files Browse the repository at this point in the history
  • Loading branch information
retinfai authored Aug 2, 2023
2 parents 40cb8d6 + c570b24 commit 085739d
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 74 deletions.
2 changes: 1 addition & 1 deletion src/controllers/controllers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sensor_msgs.msg import LaserScan
from nav_msgs.msg import Odometry

from environments.utils import process_lidar, process_odom
from environments.util import process_lidar, process_odom

class Controller(Node):
def __init__(self, node_name, car_name, step_length):
Expand Down
2 changes: 2 additions & 0 deletions src/controllers/controllers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import random

def main():
rclpy.init()

param_node = rclpy.create_node('params')

param_node.declare_parameters(
Expand Down
1 change: 1 addition & 0 deletions src/environment_interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ find_package(rosidl_default_generators REQUIRED)

rosidl_generate_interfaces(${PROJECT_NAME}
"srv/Reset.srv"
"srv/CarBeatReset.srv"
)

ament_package()
Expand Down
11 changes: 11 additions & 0 deletions src/environment_interfaces/srv/CarBeatReset.srv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
float64 gx # goal x
float64 gy # goal y
float64 cx_one # car_one x
float64 cy_one # car_one y
float64 cyaw_one # car_one yaw
float64 cx_two # car_two x
float64 cy_two # car_two y
float64 cyaw_two # car_two yaw
string flag # flag
---
bool success
210 changes: 152 additions & 58 deletions src/environments/environments/CarBeatEnvironment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,109 @@
import rclpy
from rclpy import Future
from sensor_msgs.msg import LaserScan
from launch_ros.actions import Node

from environment_interfaces.srv import Reset
import numpy as np
import rclpy
from geometry_msgs.msg import Twist
from message_filters import Subscriber, ApproximateTimeSynchronizer
from rclpy import Future
from rclpy.node import Node
from sensor_msgs.msg import LaserScan
from nav_msgs.msg import Odometry

from environment_interfaces.srv import CarBeatReset
from environments.F1tenthEnvironment import F1tenthEnvironment
from .termination import has_collided, has_flipped_over
from .util import process_odom, reduce_lidar
from .track_reset import track_info

class CarBeatEnvironment(F1tenthEnvironment):
"""
CarTrack Reinforcement Learning Environment:
Task:
Here the agent learns to drive the f1tenth car to a goal position
class CarBeatEnvironment(Node):

Observation:
It's position (x, y), orientation (w, x, y, z), lidar points (approx. ~600 rays) and the goal's position (x, y)
Action:
It's linear and angular velocity
Reward:
It's progress toward the goal plus,
50+ if it reaches the goal plus,
-25 if it collides with the wall
Termination Conditions:
When the agent is within REWARD_RANGE units or,
When the agent is within COLLISION_RANGE units
Truncation Condition:
When the number of steps surpasses MAX_STEPS
"""

def __init__(self, car_name, reward_range=1, max_steps=50, collision_range=0.2, step_length=0.5, track='track_1'):
super().__init__('car_beat', car_name, max_steps, step_length)
def __init__(self, car_one_name, car_two_name, reward_range=1, max_steps=50, collision_range=0.2, step_length=0.5, track='track_1'):
super().__init__('car_beat_environment')

# Environment Details ----------------------------------------
self.NAME = car_one_name
self.OTHER_CAR_NAME = car_two_name
self.MAX_STEPS = max_steps
self.STEP_LENGTH = step_length
self.MAX_ACTIONS = np.asarray([3, 3.14])
self.MIN_ACTIONS = np.asarray([0, -3.14])
self.MAX_STEPS_PER_GOAL = max_steps

# TODO: Update this
self.OBSERVATION_SIZE = 8 + 10 # Car position + Lidar rays
self.COLLISION_RANGE = collision_range
self.REWARD_RANGE = reward_range
self.ACTION_NUM = 2

# Reset Client -----------------------------------------------
self.step_counter = 0

# Goal/Track Info -----------------------------------------------
self.goal_number = 0
self.all_goals = track_info[track]['goals']

self.car_reset_positions = track_info[track]['reset']

self.get_logger().info('Environment Setup Complete')
# Pub/Sub ----------------------------------------------------
self.cmd_vel_pub = self.create_publisher(
Twist,
f'/{self.NAME}/cmd_vel',
10
)

self.odom_sub_one = Subscriber(
self,
Odometry,
f'/{self.NAME}/odometry',
)

self.lidar_sub_one = Subscriber(
self,
LaserScan,
f'/{self.NAME}/scan',
)

self.odom_sub_two = Subscriber(
self,
Odometry,
f'/{self.OTHER_CAR_NAME}/odometry',
)

self.lidar_sub_two = Subscriber(
self,
LaserScan,
f'/{self.OTHER_CAR_NAME}/scan',
)

self.message_filter = ApproximateTimeSynchronizer(
[self.odom_sub_one, self.lidar_sub_one, self.odom_sub_two, self.lidar_sub_two],
10,
0.1,
)

self.message_filter.registerCallback(self.message_filter_callback)

self.observation_future = Future()

# Reset Client -----------------------------------------------
self.reset_client = self.create_client(
CarBeatReset,
'car_beat_reset'
)

while not self.reset_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('reset service not available, waiting again...')

self.timer = self.create_timer(step_length, self.timer_cb)
self.timer_future = Future()

def reset(self):
# self.get_logger().info('Environment Reset Called')


self.step_counter = 0

self.set_velocity(0, 0)
# self.get_logger().info('Velocity set')

# TODO: Remove Hard coded-ness of 10x10
self.goal_number = 0
Expand All @@ -69,22 +115,66 @@ def reset(self):
while not self.timer_future.done():
rclpy.spin_once(self)

# self.get_logger().info('Sleep called')

self.timer_future = Future()

self.call_reset_service()

# self.get_logger().info('Reset Called and returned')
observation = self.get_observation()

# self.get_logger().info('Observation returned')
info = {}

return observation, info

def step(self, action):
self.step_counter += 1

state = self.get_observation()

lin_vel, ang_vel = action
self.set_velocity(lin_vel, ang_vel)

while not self.timer_future.done():
rclpy.spin_once(self)

self.timer_future = Future()

next_state = self.get_observation()
reward = self.compute_reward(state, next_state)
terminated = self.is_terminated(next_state)
truncated = self.step_counter >= self.MAX_STEPS
info = {}

return next_state, reward, terminated, truncated, info

def message_filter_callback(self, odom_one: Odometry, lidar_one: LaserScan, odom_two: Odometry, lidar_two: LaserScan):
self.observation_future.set_result({'odom_one': odom_one, 'lidar_one': lidar_one, 'odom_two': odom_two, 'lidar_two': lidar_two})

def get_data(self):
rclpy.spin_until_future_complete(self, self.observation_future)
future = self.observation_future
self.observation_future = Future()
data = future.result()
return data['odom_one'], data['lidar_one'], data['odom_two'], data['lidar_two']

def set_velocity(self, linear, angular):
"""
Publish Twist messages to f1tenth cmd_vel topic
"""
velocity_msg = Twist()
velocity_msg.angular.z = float(angular)
velocity_msg.linear.x = float(linear)

self.cmd_vel_pub.publish(velocity_msg)

def sleep(self):
while not self.timer_future.done():
rclpy.spin_once(self)

def timer_cb(self):
self.timer_future.set_result(True)

def is_terminated(self, state):
return has_collided(state[8:], self.COLLISION_RANGE) \
return has_collided(state[8:19], self.COLLISION_RANGE) \
or has_flipped_over(state[2:6])

def generate_goal(self, number):
Expand All @@ -98,12 +188,21 @@ def call_reset_service(self):

x, y = self.goal_position

request = Reset.Request()
request = CarBeatReset.Request()

request.gx = x
request.gy = y
request.cx = self.car_reset_positions['x']
request.cy = self.car_reset_positions['y']
request.cyaw = self.car_reset_positions['yaw']

request.cx_one = self.car_reset_positions['x']
request.cy_one = self.car_reset_positions['y']
request.cyaw_one = self.car_reset_positions['yaw']

request.cx_two = self.all_goals[0][0]
request.cy_two = self.all_goals[0][1]

# TODO: Fix this
request.cyaw_two = self.car_reset_positions['yaw']

request.flag = "car_and_goal"

future = self.reset_client.call_async(request)
Expand All @@ -119,7 +218,7 @@ def update_goal_service(self, number):
x, y = self.generate_goal(number)
self.goal_position = [x, y]

request = Reset.Request()
request = CarBeatReset.Request()
request.gx = x
request.gy = y
request.flag = "goal_only"
Expand All @@ -132,25 +231,20 @@ def update_goal_service(self, number):
def get_observation(self):

# Get Position and Orientation of F1tenth
odom, lidar = self.get_data()
odom = process_odom(odom)
odom_one, lidar_one, odom_two, lidar_two = self.get_data()

reduced_range = reduce_lidar(lidar)
odom_one = process_odom(odom_one)
odom_two = process_odom(odom_two)

# Get Goal Position
return odom + reduced_range
lidar_one = reduce_lidar(lidar_one)
lidar_two = reduce_lidar(lidar_two)

def compute_reward(self, state, next_state):

# TESTING ONLY
self.get_logger().info(
f'odom_one: {odom_one} \n\nlidar_one: {lidar_one} \n\nodom_two: {odom_two} \n\nlidar_two: {lidar_two}')

# if self.goal_number < len(self.all_goals) - 1:
# self.goal_number += 1
# else:
# self.goal_number = 0
return odom_one + lidar_one + odom_two + lidar_two + self.goal_position

# self.update_goal_service(self.goal_number)
# ==============================================================
def compute_reward(self, state, next_state):

reward = 0

Expand All @@ -167,7 +261,7 @@ def compute_reward(self, state, next_state):
self.step_counter = 0
self.update_goal_service(self.goal_number)

if has_collided(next_state[8:], self.COLLISION_RANGE) or has_flipped_over(next_state[2:6]):
if has_collided(next_state[8:19], self.COLLISION_RANGE) or has_flipped_over(next_state[2:6]):
reward -= 25 # TODO: find optimal value for this

return reward
13 changes: 8 additions & 5 deletions src/environments/environments/CarBeatReset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor

from environment_interfaces.srv import Reset
from environment_interfaces.srv import CarBeatReset as CarBeatResetSrv
from f1tenth_control.SimulationServices import SimulationServices
from ros_gz_interfaces.srv import SetEntityPose
from ros_gz_interfaces.msg import Entity
Expand All @@ -19,7 +19,7 @@ def __init__(self):
super().__init__('car_beat_reset')

srv_cb_group = MutuallyExclusiveCallbackGroup()
self.srv = self.create_service(Reset, 'car_beat_reset', callback=self.service_callback, callback_group=srv_cb_group)
self.srv = self.create_service(CarBeatResetSrv, 'car_beat_reset', callback=self.service_callback, callback_group=srv_cb_group)

set_pose_cb_group = MutuallyExclusiveCallbackGroup()
self.set_pose_client = self.create_client(
Expand All @@ -34,10 +34,11 @@ def __init__(self):

def service_callback(self, request, response):

self.get_logger().info(f'Reset Service Request Received: relocating goal to x={request.cx} y={request.cy}')
self.get_logger().info(f'Reset Service Request Received: relocating goal to x={request.cx_one} y={request.cy_one}')

goal_req = self.create_request('goal', x=request.gx, y=request.gy, z=1)
car_req = self.create_request('f1tenth_one', x=request.cx, y=request.cy, z=0.1, yaw=request.cyaw)
car_req_one = self.create_request('f1tenth_one', x=request.cx_one, y=request.cy_one, z=0.1, yaw=request.cyaw_one)
car_req_two = self.create_request('f1tenth_two', x=request.cx_two, y=request.cy_two, z=0.1, yaw=request.cyaw_two)

while not self.set_pose_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('set_pose service not available, waiting again...')
Expand All @@ -47,7 +48,9 @@ def service_callback(self, request, response):
self.set_pose_client.call(goal_req)
else:
self.set_pose_client.call(goal_req)
self.set_pose_client.call(car_req)
self.set_pose_client.call(car_req_one)
self.set_pose_client.call(car_req_two)


# self.get_logger().info('Successfully Reset')
response.success = True
Expand Down
Loading

0 comments on commit 085739d

Please sign in to comment.