-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #91 from UoA-CARES/76-carbeat
- Loading branch information
Showing
15 changed files
with
442 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import math | ||
|
||
import numpy as np | ||
import rclpy | ||
from rclpy import Future | ||
from sensor_msgs.msg import LaserScan | ||
|
||
from environment_interfaces.srv import Reset | ||
from environments.F1tenthEnvironment import F1tenthEnvironment | ||
from .termination import has_collided, has_flipped_over | ||
from .util import process_odom, reduce_lidar | ||
from .track_reset import track_info | ||
|
||
class CarBeatEnvironment(F1tenthEnvironment):
    """
    CarBeat Reinforcement Learning Environment.

    Task:
        The agent learns to drive the f1tenth car through a sequence of
        goal positions laid out along the chosen track.
    Observation (OBSERVATION_SIZE = 18):
        Processed odometry (8 values — see process_odom; indices 0-1 are
        treated below as x, y position and 2-6 as containing the orientation
        quaternion — TODO confirm exact layout) followed by a reduced lidar
        scan (10 ranges — see reduce_lidar). NOTE(review): the goal position
        is NOT part of the observation returned by get_observation().
    Action:
        The car's linear and angular velocity.
    Reward:
        Progress toward the current goal (previous distance minus current
        distance), plus 50 each time a goal is reached, minus 25 on
        collision or flip-over.
    Termination Conditions:
        The car collides with a wall (a reduced lidar range below
        COLLISION_RANGE) or flips over. Reaching a goal does not terminate
        the episode; it advances to the next goal instead.
    Truncation Condition:
        When the number of steps surpasses MAX_STEPS (handled by the base
        class).
    """

    def __init__(self, car_name, reward_range=1, max_steps=50, collision_range=0.2, step_length=0.5, track='track_1'):
        """Set up environment constants and the goal/reset data for *track*.

        car_name: name of the car entity in the simulation.
        reward_range: distance within which a goal counts as reached.
        max_steps: step budget per goal before truncation.
        collision_range: lidar distance below which a collision is declared.
        step_length: seconds per environment step (passed to the base class).
        track: key into track_info selecting the goal list and reset pose.
        """
        super().__init__('car_beat', car_name, max_steps, step_length)

        # Environment Details ----------------------------------------
        self.MAX_STEPS_PER_GOAL = max_steps
        self.OBSERVATION_SIZE = 8 + 10  # processed odometry (8) + reduced lidar rays (10)
        self.COLLISION_RANGE = collision_range
        self.REWARD_RANGE = reward_range

        # Goal / reset bookkeeping -----------------------------------
        self.goal_number = 0
        self.all_goals = track_info[track]['goals']

        self.car_reset_positions = track_info[track]['reset']

        self.get_logger().info('Environment Setup Complete')

    def reset(self):
        """Stop the car, move the car and the first goal back to the track's
        reset pose, and return (observation, info) for the new episode."""
        self.step_counter = 0

        self.set_velocity(0, 0)

        # TODO: Remove Hard coded-ness of 10x10
        self.goal_number = 0
        self.goal_position = self.generate_goal(self.goal_number)

        # Wait out the remainder of the current step period before resetting
        # (timer_future is presumably completed by the base class's step
        # timer — TODO confirm in F1tenthEnvironment).
        while not self.timer_future.done():
            rclpy.spin_once(self)

        self.timer_future = Future()

        self.call_reset_service()

        observation = self.get_observation()

        info = {}

        return observation, info

    def is_terminated(self, state):
        # state[8:] are the reduced lidar ranges; state[2:6] is taken as the
        # orientation part of the processed odometry.
        return has_collided(state[8:], self.COLLISION_RANGE) \
            or has_flipped_over(state[2:6])

    def generate_goal(self, number):
        """Return the coordinates of goal index *number*, wrapping around
        the track's goal list."""
        print("Goal", number, "spawned")
        return self.all_goals[number % len(self.all_goals)]

    def call_reset_service(self):
        """
        Reset the car and goal position via the environment's Reset service,
        blocking until the service responds.
        """

        x, y = self.goal_position

        request = Reset.Request()
        request.gx = x
        request.gy = y
        request.cx = self.car_reset_positions['x']
        request.cy = self.car_reset_positions['y']
        request.cyaw = self.car_reset_positions['yaw']
        request.flag = "car_and_goal"

        future = self.reset_client.call_async(request)
        rclpy.spin_until_future_complete(self, future)

        return future.result()

    def update_goal_service(self, number):
        """
        Move only the goal marker to goal index *number*, blocking until the
        service responds. Also updates self.goal_position.
        """

        x, y = self.generate_goal(number)
        self.goal_position = [x, y]

        request = Reset.Request()
        request.gx = x
        request.gy = y
        request.flag = "goal_only"

        future = self.reset_client.call_async(request)
        rclpy.spin_until_future_complete(self, future)

        return future.result()

    def get_observation(self):
        """Return the processed odometry concatenated with the reduced
        lidar scan (the goal position is not included)."""

        # Get Position and Orientation of F1tenth
        odom, lidar = self.get_data()
        odom = process_odom(odom)

        reduced_range = reduce_lidar(lidar)

        return odom + reduced_range

    def compute_reward(self, state, next_state):
        """Return the reward for the transition state -> next_state.

        Progress toward the current goal, +50 when the goal is reached
        (which also advances to the next goal and restarts the per-goal
        step budget), -25 on collision or flip-over.
        """

        # TESTING ONLY

        # if self.goal_number < len(self.all_goals) - 1:
        #     self.goal_number += 1
        # else:
        #     self.goal_number = 0

        # self.update_goal_service(self.goal_number)
        # ==============================================================

        reward = 0

        goal_position = self.goal_position

        # Progress reward: positive when the car moved closer to the goal.
        prev_distance = math.dist(goal_position, state[:2])
        current_distance = math.dist(goal_position, next_state[:2])

        reward += prev_distance - current_distance

        if current_distance < self.REWARD_RANGE:
            reward += 50
            self.goal_number += 1
            self.step_counter = 0  # restart the per-goal step budget
            self.update_goal_service(self.goal_number)

        if has_collided(next_state[8:], self.COLLISION_RANGE) or has_flipped_over(next_state[2:6]):
            reward -= 25  # TODO: find optimal value for this

        return reward
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import sys | ||
import rclpy | ||
from rclpy.node import Node | ||
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup | ||
from rclpy.executors import MultiThreadedExecutor | ||
|
||
from environment_interfaces.srv import Reset | ||
from f1tenth_control.SimulationServices import SimulationServices | ||
from ros_gz_interfaces.srv import SetEntityPose | ||
from ros_gz_interfaces.msg import Entity | ||
from geometry_msgs.msg import Pose, Point | ||
|
||
from ament_index_python import get_package_share_directory | ||
|
||
from .util import get_quaternion_from_euler | ||
|
||
class CarBeatReset(Node):
    """ROS 2 node exposing a 'car_beat_reset' service that teleports the goal
    marker (and, unless the request flag is "goal_only", the car) using the
    Gazebo set_pose service."""

    def __init__(self):
        super().__init__('car_beat_reset')

        # Separate callback groups so the blocking service callback and the
        # set_pose client can be serviced concurrently by a
        # MultiThreadedExecutor.
        srv_cb_group = MutuallyExclusiveCallbackGroup()
        self.srv = self.create_service(Reset, 'car_beat_reset', callback=self.service_callback, callback_group=srv_cb_group)

        set_pose_cb_group = MutuallyExclusiveCallbackGroup()
        self.set_pose_client = self.create_client(
            SetEntityPose,
            'world/empty/set_pose',  # fixed: was an f-string with no placeholders
            callback_group=set_pose_cb_group
        )

        while not self.set_pose_client.wait_for_service(timeout_sec=1.0):
            self.get_logger().info('set_pose service not available, waiting again...')

    def service_callback(self, request, response):
        """Handle a Reset request: move the goal to (gx, gy); unless
        request.flag == "goal_only", also move the car to (cx, cy, cyaw)."""

        # BUG FIX: the log previously printed the *car* coordinates
        # (request.cx / request.cy) while claiming to report the goal.
        self.get_logger().info(f'Reset Service Request Received: relocating goal to x={request.gx} y={request.gy}')

        goal_req = self.create_request('goal', x=request.gx, y=request.gy, z=1)

        while not self.set_pose_client.wait_for_service(timeout_sec=1.0):
            self.get_logger().info('set_pose service not available, waiting again...')

        # TODO: Call async and wait for both to execute
        # The goal is moved in every case; the car only when not "goal_only".
        self.set_pose_client.call(goal_req)
        if request.flag != "goal_only":
            car_req = self.create_request('f1tenth_one', x=request.cx, y=request.cy, z=0.1, yaw=request.cyaw)
            self.set_pose_client.call(car_req)

        response.success = True

        return response

    def create_request(self, name, x=0, y=0, z=0, roll=0, pitch=0, yaw=0):
        """Build a SetEntityPose request for the named entity at the given
        position and orientation (Euler angles in radians)."""
        req = SetEntityPose.Request()

        req.entity = Entity()
        req.entity.name = name
        req.entity.type = 2  # presumably Entity.MODEL — TODO confirm against ros_gz_interfaces/msg/Entity

        req.pose = Pose()
        req.pose.position = Point()

        req.pose.position.x = float(x)
        req.pose.position.y = float(y)
        req.pose.position.z = float(z)

        # get_quaternion_from_euler returns (x, y, z, w).
        orientation = get_quaternion_from_euler(roll, pitch, yaw)
        req.pose.orientation.x = orientation[0]
        req.pose.orientation.y = orientation[1]
        req.pose.orientation.z = orientation[2]
        req.pose.orientation.w = orientation[3]

        return req
|
||
def main():
    """Spawn the goal marker into the simulation and spin the reset node."""
    rclpy.init()

    reset_service = CarBeatReset()

    # Fixed: this lookup was performed twice in the original; once suffices.
    pkg_environments = get_package_share_directory('environments')

    services = SimulationServices('empty')

    services.spawn(sdf_filename=f"{pkg_environments}/sdf/goal.sdf", pose=[1, 1, 1], name='goal')

    reset_service.get_logger().info('Environment Spawning Complete')

    # MultiThreadedExecutor so the service callback can block on the
    # set_pose client without deadlocking the node (see callback groups).
    executor = MultiThreadedExecutor()
    executor.add_node(reset_service)

    executor.spin()

    reset_service.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
Oops, something went wrong.