Skip to content

Commit

Permalink
Adding FastAI(pytorch) as a first class citizen to the pipeline (#982)
Browse files Browse the repository at this point in the history
* Squashing commits

* Fixed basic issues

* Refactored to use torch_data and removed specialised torch datasets.

* Added fastai to training tests

* Test will now do fastai test

I will need to remove the code around TFLite and TensorRT for fastai.

* Added Fastai to the training tests and fixed issue with fastai implementation

* Fixing namedtuple to have correct number of args
  • Loading branch information
adricl authored Feb 23, 2022
1 parent a0eb5df commit 064b327
Show file tree
Hide file tree
Showing 12 changed files with 396 additions and 48 deletions.
287 changes: 268 additions & 19 deletions donkeycar/parts/fastai.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,280 @@
import os
from fastai.vision import *
"""
fastai.py
Methods to create, use, save and load pilots. Pilots contain the high-level
logic used to determine the angle and throttle of a vehicle. Pilots can
include one or more models to help direct the vehicle's motion.
"""
from abc import ABC, abstractmethod

import numpy as np
from pathlib import Path
from typing import Dict, Tuple, Optional, Union, List, Sequence, Callable
from logging import getLogger

import donkeycar as dk
import torch
from donkeycar.utils import normalize_image, linear_bin
from donkeycar.pipeline.types import TubRecord, TubDataset
from donkeycar.pipeline.sequence import TubSequence
from donkeycar.parts.interpreter import FastAIInterpreter, Interpreter, KerasInterpreter
from donkeycar.parts.pytorch.torch_data import TorchTubDataset, get_default_transform

class FastAiPilot(object):
from fastai.vision.all import *
from fastai.data.transforms import *
from fastai import optimizer as fastai_optimizer
from torch.utils.data import IterableDataset, DataLoader

def __init__(self):
self.learn = None
from torchvision import transforms

ONE_BYTE_SCALE = 1.0 / 255.0

# type of x
XY = Union[float, np.ndarray, Tuple[Union[float, np.ndarray], ...]]

logger = getLogger(__name__)

class FastAiPilot(ABC):
    """
    Base class for Fast AI models that will provide steering and throttle to
    guide a car.
    """

    def __init__(self,
                 interpreter: Interpreter = FastAIInterpreter(),
                 input_shape: Tuple[int, ...] = (120, 160, 3)) -> None:
        # NOTE(review): the default FastAIInterpreter() instance is created
        # once at class-definition time and shared by every pilot constructed
        # without an explicit interpreter — confirm this is intended.
        self.model: Optional[Model] = None
        self.input_shape = input_shape
        self.optimizer = "adam"
        self.interpreter = interpreter
        # Let the interpreter build the concrete model via create_model().
        self.interpreter.set_model(self)
        self.learner = None
        logger.info(f'Created {self} with interpreter: {interpreter}')

    def load(self, model_path):
        # Load a previously trained model through the interpreter.
        if torch.cuda.is_available():
            # NOTE(review): the print/`defaults.device` lines look like
            # residue of the pre-refactor implementation interleaved by the
            # diff view; `defaults` presumably comes from the fastai star
            # import — verify against the actual file.
            print("using cuda for torch inference")
            defaults.device = torch.device('cuda')
        logger.info(f'Loading model {model_path}')
        self.interpreter.load(model_path)

    def load_weights(self, model_path: str, by_name: bool = True) -> None:
        # Delegate weight loading to the interpreter.
        self.interpreter.load_weights(model_path, by_name=by_name)

    def shutdown(self) -> None:
        # Donkeycar part interface; nothing to clean up here.
        pass

    def compile(self) -> None:
        # No-op in the base class; concrete pilots may override.
        pass

    @abstractmethod
    def create_model(self):
        # Concrete pilots return the torch module to train / run inference on.
        pass

    def set_optimizer(self, optimizer_type: str,
                      rate: float, decay: float) -> None:
        # Map a config string onto a fastai optimizer factory.
        if optimizer_type == "adam":
            optimizer = fastai_optimizer.Adam(lr=rate, wd=decay)
        elif optimizer_type == "sgd":
            optimizer = fastai_optimizer.SGD(lr=rate, wd=decay)
        elif optimizer_type == "rmsprop":
            optimizer = fastai_optimizer.RMSprop(lr=rate, wd=decay)
        else:
            # NOTE(review): this print looks like diff residue from the old
            # cuda check — it is unrelated to the unknown-optimizer error.
            print("cuda not available for torch inference")
            raise Exception(f"Unknown optimizer type: {optimizer_type}")
        self.interpreter.set_optimizer(optimizer)

    # shape
    def get_input_shapes(self):
        return self.interpreter.get_input_shapes()

    def seq_size(self) -> int:
        # Non-sequence model: no time dimension.
        return 0

    def run(self, img_arr: np.ndarray, other_arr: List[float] = None) \
            -> Tuple[Union[float, torch.tensor], ...]:
        """
        Donkeycar parts interface to run the part in the loop.
        :param img_arr: uint8 [0,255] numpy array with image data
        :param other_arr: numpy array of additional data to be used in the
                          pilot, like IMU array for the IMU model or a
                          state vector in the Behavioural model
        :return: tuple of (angle, throttle)
        """
        # Normalise / convert the image to a tensor (no resize here).
        transform = get_default_transform(resize=False)
        norm_arr = transform(img_arr)
        # NOTE(review): truthiness check means an empty other_arr is treated
        # like None — confirm `if other_arr is not None` was not intended.
        tensor_other_array = torch.FloatTensor(other_arr) if other_arr else None
        return self.inference(norm_arr, tensor_other_array)

    def inference(self, img_arr: torch.tensor, other_arr: Optional[torch.tensor]) \
            -> Tuple[Union[float, torch.tensor], ...]:
        """ Inferencing using the interpreter
            :param img_arr:     float32 [0,1] numpy array with normalized image
                                data
            :param other_arr:   tensor array of additional data to be used in the
                                pilot, like IMU array for the IMU model or a
                                state vector in the Behavioural model
            :return:            tuple of (angle, throttle)
        """
        out = self.interpreter.predict(img_arr, other_arr)
        return self.interpreter_to_output(out)

    def inference_from_dict(self, input_dict: Dict[str, np.ndarray]) \
            -> Tuple[Union[float, np.ndarray], ...]:
        """ Inferencing using the interpreter
            :param input_dict:  input dictionary of str and np.ndarray
            :return:            typically tuple of (angle, throttle)
        """
        output = self.interpreter.predict_from_dict(input_dict)
        return self.interpreter_to_output(output)

    @abstractmethod
    def interpreter_to_output(
            self,
            interpreter_out: Sequence[Union[float, np.ndarray]]) \
            -> Tuple[Union[float, np.ndarray], ...]:
        """ Virtual method to be implemented by child classes for conversion
            :param interpreter_out:  input data
            :return:                 output values, possibly tuple of np.ndarray
        """
        pass

    # NOTE(review): the next three statements appear to be residue of the old
    # `load` implementation interleaved by the diff view; at class level they
    # reference undefined names (`model_path`, `self`) and cannot execute.
    path = os.path.dirname(model_path)
    fname = os.path.basename(model_path)
    self.learn = load_learner(path=path, file=fname)

    def train(self,
              model_path: str,
              train_data: TorchTubDataset,
              train_steps: int,
              batch_size: int,
              validation_data: TorchTubDataset,
              validation_steps: int,
              epochs: int,
              verbose: int = 1,
              min_delta: float = .0005,
              patience: int = 5,
              show_plot: bool = False):
        """
        trains the model
        """
        assert isinstance(self.interpreter, FastAIInterpreter)
        model = self.interpreter.model

        # NOTE(review): this nested `run` is residue of the old pilot's run()
        # interleaved by the diff view; it is never called here.
        def run(self, img):
            t = pil2tensor(img, dtype=np.float32)  # converts to tensor
            im = Image(t)  # Convert to fastAi Image - this class has "apply_tfms"
        dataLoader = DataLoaders.from_dsets(train_data, validation_data, bs=batch_size, shuffle=False)
        if torch.cuda.is_available():
            dataLoader.cuda()

        #dataLoaderTest = self.dataBlock.dataloaders.test_dl(validation_data, with_labels=True)
        #print(dataLoader.train[0])

        # Early stopping + best-model checkpointing on validation loss.
        callbacks = [
            EarlyStoppingCallback(monitor='valid_loss',
                                  patience=patience,
                                  min_delta=min_delta),
            SaveModelCallback(monitor='valid_loss',
                              every_epoch=False
                              )
        ]

        self.learner = Learner(dataLoader, model, loss_func=self.loss, path=Path(model_path).parent)

        logger.info(self.learner.summary())
        logger.info(self.learner.loss_func)

        # fastai learning-rate finder; use its suggested value.
        lr_result = self.learner.lr_find()
        suggestedLr = float(lr_result[0])

        logger.info(f"Suggested Learning Rate {suggestedLr}")

        self.learner.fit_one_cycle(epochs, suggestedLr, cbs=callbacks)

        # Persist the whole torch model object (not just the weights).
        torch.save(self.learner.model, model_path)

        if show_plot:
            self.learner.recorder.plot_loss()
            plt.savefig(Path(model_path).with_suffix('.png'))

        # Return keras-style history dict so callers can plot/inspect losses.
        history = { "loss" : list(map((lambda x: x.item()), self.learner.recorder.losses)) }
        return history

    def __str__(self) -> str:
        """ For printing model initialisation """
        return type(self).__name__

    # NOTE(review): the following three statements are residue of the old
    # run() implementation (diff artifact); `self`/`im` are undefined at
    # class level and this cannot execute.
    pred = self.learn.predict(im)
    steering = float(pred[0].data[0])
    throttle = float(pred[0].data[1])
class FastAILinear(FastAiPilot):
    """
    The FastAILinear pilot uses one neuron to output a continuous value for
    each of steering and throttle. The output is not bounded.
    """

    def __init__(self,
                 interpreter: Interpreter = FastAIInterpreter(),
                 input_shape: Tuple[int, ...] = (120, 160, 3),
                 num_outputs: int = 2):
        # loss must exist before super().__init__ wires up the interpreter,
        # since training reads self.loss.
        self.num_outputs = num_outputs
        self.loss = MSELossFlat()

        super().__init__(interpreter, input_shape)

    def create_model(self):
        # Concrete torch module, defined below in this file.
        return Linear()

    def compile(self):
        # NOTE(review): `self.optimizer = self.optimizer` is a no-op, and
        # assigning the string 'mse' overwrites the MSELossFlat() instance
        # set in __init__ — confirm this method is actually exercised.
        self.optimizer = self.optimizer
        self.loss = 'mse'

    def interpreter_to_output(self, interpreter_out):
        # Affine map of the raw outputs: presumably network outputs lie in
        # [0, 1] and are rescaled to the [-1, 1] control range — TODO confirm.
        interpreter_out = (interpreter_out * 2) - 1
        steering = interpreter_out[0]
        throttle = interpreter_out[1]
        return steering, throttle


    def y_transform(self, record: Union[TubRecord, List[TubRecord]]) -> XY:
        # Extract the (angle, throttle) training labels from a tub record.
        assert isinstance(record, TubRecord), 'TubRecord expected'
        angle: float = record.underlying['user/angle']
        throttle: float = record.underlying['user/throttle']
        return angle, throttle

    def y_translate(self, y: XY) -> Dict[str, Union[float, List[float]]]:
        # Name the two outputs for dictionary-based pipelines.
        assert isinstance(y, tuple), 'Expected tuple'
        angle, throttle = y
        return {'n_outputs0': angle, 'n_outputs1': throttle}

    def output_shapes(self):
        # need to cut off None from [None, 120, 160, 3] tensor shape
        img_shape = self.get_input_shapes()[0][1:]
        # NOTE(review): no return statement — the computed img_shape is
        # discarded; this method looks truncated/unfinished in this view.


class Linear(nn.Module):
    """
    Convolutional network for the linear pilot.

    Five conv stages (the 3x3 conv layer is applied twice, sharing its
    weights), two fully connected layers, and two single-unit heads whose
    results are concatenated into a (batch, 2) tensor of (angle, throttle).
    Expects (batch, 3, 120, 160) input.
    """

    def __init__(self):
        super(Linear, self).__init__()
        self.dropout = 0.1  # dropout probability shared by every stage
        # Layer creation order is kept stable so seeded weight initialisation
        # and previously saved state dicts remain compatible.
        self.conv24 = nn.Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))
        self.conv32 = nn.Conv2d(24, 32, kernel_size=(5, 5), stride=(2, 2))
        self.conv64_5 = nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2))
        self.conv64_3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        self.fc1 = nn.Linear(6656, 100)
        self.fc2 = nn.Linear(100, 50)
        self.drop = nn.Dropout(self.dropout)
        self.relu = nn.ReLU()
        self.output1 = nn.Linear(50, 1)
        self.output2 = nn.Linear(50, 1)
        self.flatten = nn.Flatten()

    def forward(self, x):
        # Conv backbone; conv64_3 runs twice on purpose — with a 120x160
        # input this yields a 64x8x13 map, matching fc1's 6656 in-features.
        for stage in (self.conv24, self.conv32, self.conv64_5,
                      self.conv64_3, self.conv64_3):
            x = self.drop(self.relu(stage(x)))
        hidden = self.drop(self.fc1(self.flatten(x)))
        hidden = self.drop(self.fc2(hidden))
        angle = self.output1(hidden)
        throttle = self.output2(hidden)
        return torch.cat((angle, throttle), 1)
64 changes: 63 additions & 1 deletion donkeycar/parts/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing import Union, Sequence, List

import tensorflow as tf
import torch
from tensorflow import keras

from tensorflow.python.framework.convert_to_constants import \
convert_variables_to_constants_v2 as convert_var_to_const
from tensorflow.python.saved_model import tag_constants, signature_constants


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -105,6 +105,9 @@ def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \
def predict_from_dict(self, input_dict) -> Sequence[Union[float, np.ndarray]]:
    # Interface hook: run inference from a dict of named input arrays.
    # NOTE(review): shown without its enclosing class in this diff view.
    pass

def summary(self) -> str:
    # New interface hook added for FastAI support; the base implementation
    # is a no-op and returns None despite the `str` annotation.
    pass

def __str__(self) -> str:
    """ For printing interpreter """
    # Just the class name, e.g. 'FastAIInterpreter'.
    return type(self).__name__
Expand Down Expand Up @@ -165,6 +168,65 @@ def load_weights(self, model_path: str, by_name: bool = True) -> \
assert self.model, 'Model not set'
self.model.load_weights(model_path, by_name=by_name)

def summary(self) -> str:
    # NOTE(review): keras Model.summary() prints to stdout and returns
    # None — confirm callers expect a string rather than the side effect.
    return self.model.summary()

class FastAIInterpreter(Interpreter):
    """
    Interpreter implementation for fastai / pytorch pilots. Holds the torch
    module built by the pilot and runs inference on it directly.
    """

    def __init__(self):
        super().__init__()
        # Fix: the original line read `self.model: None`, which is a bare
        # annotation (with None as the "type") and never assigns the
        # attribute, so reading self.model before set_model() raised
        # AttributeError. Assign a real default instead.
        self.model = None
        # NOTE(review): these function-scope imports only bind names inside
        # __init__; they appear to act as an early fastai availability
        # check — confirm before removing.
        from fastai import learner as fastai_learner
        from fastai import optimizer as fastai_optimizer

    def set_model(self, pilot: 'FastAiPilot') -> None:
        # The pilot constructs the concrete nn.Module.
        self.model = pilot.create_model()

    def set_optimizer(self, optimizer: 'fastai_optimizer') -> None:
        # Stash the fastai optimizer on the model for later training.
        self.model.optimizer = optimizer

    def get_input_shapes(self):
        assert self.model, 'Model not set'
        return [inp.shape for inp in self.model.inputs]

    def compile(self, **kwargs):
        # Nothing to do for an eager pytorch model.
        pass

    def invoke(self, inputs):
        """ Run the model and strip the batch dimension from the result. """
        outputs = self.model(inputs)
        # for functional models the output here is a list
        if type(outputs) is list:
            # as we invoke the interpreter with a batch size of one we remove
            # the additional dimension here again
            output = [output.numpy().squeeze(axis=0) for output in outputs]
            return output
        # for sequential models the output shape is (1, n) with n = output dim
        else:
            return outputs.detach().numpy().squeeze(axis=0)

    def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \
            -> Sequence[Union[float, np.ndarray]]:
        # Add a batch dimension for the single input image.
        inputs = torch.unsqueeze(img_arr, 0)
        if other_arr is not None:
            #other_arr = np.expand_dims(other_arr, axis=0)
            # NOTE(review): this branch passes the *un*-batched img_arr —
            # confirm whether [inputs, other_arr] was intended.
            inputs = [img_arr, other_arr]
        return self.invoke(inputs)

    def load(self, model_path: str) -> None:
        """ Load a fully serialized torch model from model_path onto the
            available device and switch it to eval mode. """
        logger.info(f'Loading model {model_path}')
        if torch.cuda.is_available():
            logger.info("using cuda for torch inference")
            self.model = torch.load(model_path)
        else:
            logger.info("cuda not available for torch inference")
            self.model = torch.load(model_path, map_location=torch.device('cpu'))

        logger.info(self.model)
        self.model.eval()

    def summary(self) -> str:
        # NOTE(review): returns the model object itself, not a str; callers
        # appear to rely on the logger stringifying it.
        return self.model

class TfLite(Interpreter):
"""
Expand Down
Loading

0 comments on commit 064b327

Please sign in to comment.