diff --git a/donkeycar/parts/fastai.py b/donkeycar/parts/fastai.py index 1ab0769e7..4b3d59b32 100644 --- a/donkeycar/parts/fastai.py +++ b/donkeycar/parts/fastai.py @@ -1,31 +1,280 @@ -import os -from fastai.vision import * +""" + +fastai.py + +Methods to create, use, save and load pilots. Pilots contain the highlevel +logic used to determine the angle and throttle of a vehicle. Pilots can +include one or more models to help direct the vehicles motion. + +""" +from abc import ABC, abstractmethod + +import numpy as np +from pathlib import Path +from typing import Dict, Tuple, Optional, Union, List, Sequence, Callable +from logging import getLogger + +import donkeycar as dk import torch +from donkeycar.utils import normalize_image, linear_bin +from donkeycar.pipeline.types import TubRecord, TubDataset +from donkeycar.pipeline.sequence import TubSequence +from donkeycar.parts.interpreter import FastAIInterpreter, Interpreter, KerasInterpreter +from donkeycar.parts.pytorch.torch_data import TorchTubDataset, get_default_transform -class FastAiPilot(object): +from fastai.vision.all import * +from fastai.data.transforms import * +from fastai import optimizer as fastai_optimizer +from torch.utils.data import IterableDataset, DataLoader - def __init__(self): - self.learn = None +from torchvision import transforms + +ONE_BYTE_SCALE = 1.0 / 255.0 + +# type of x +XY = Union[float, np.ndarray, Tuple[Union[float, np.ndarray], ...]] + +logger = getLogger(__name__) + +class FastAiPilot(ABC): + """ + Base class for Fast AI models that will provide steering and throttle to + guide a car. + """ + + def __init__(self, + interpreter: Interpreter = FastAIInterpreter(), + input_shape: Tuple[int, ...] = (120, 160, 3)) -> None: + self.model: Optional[Model] = None + self.input_shape = input_shape + self.optimizer = "adam" + self.interpreter = interpreter + self.interpreter.set_model(self) + self.learner = None + logger.info(f'Created {self} with interpreter: {interpreter}') def load(self, model_path): - if torch.cuda.is_available(): - print("using cuda for torch inference") - defaults.device = torch.device('cuda') + logger.info(f'Loading model {model_path}') + self.interpreter.load(model_path) + + def load_weights(self, model_path: str, by_name: bool = True) -> None: + self.interpreter.load_weights(model_path, by_name=by_name) + + def shutdown(self) -> None: + pass + + def compile(self) -> None: + pass + + @abstractmethod + def create_model(self): + pass + + def set_optimizer(self, optimizer_type: str, + rate: float, decay: float) -> None: + if optimizer_type == "adam": + optimizer = fastai_optimizer.Adam(lr=rate, wd=decay) + elif optimizer_type == "sgd": + optimizer = fastai_optimizer.SGD(lr=rate, wd=decay) + elif optimizer_type == "rmsprop": + optimizer = fastai_optimizer.RMSprop(lr=rate, wd=decay) else: - print("cuda not available for torch inference") + raise Exception(f"Unknown optimizer type: {optimizer_type}") + self.interpreter.set_optimizer(optimizer) + + # shape + def get_input_shapes(self): + return self.interpreter.get_input_shapes() + + def seq_size(self) -> int: + return 0 + + def run(self, img_arr: np.ndarray, other_arr: List[float] = None) \ + -> Tuple[Union[float, torch.tensor], ...]: + """ + Donkeycar parts interface to run the part in the loop. + + :param img_arr: uint8 [0,255] numpy array with image data + :param other_arr: numpy array of additional data to be used in the + pilot, like IMU array for the IMU model or a + state vector in the Behavioural model + :return: tuple of (angle, throttle) + """ + transform = get_default_transform(resize=False) + norm_arr = transform(img_arr) + tensor_other_array = torch.FloatTensor(other_arr) if other_arr else None + return self.inference(norm_arr, tensor_other_array) + + def inference(self, img_arr: torch.tensor, other_arr: Optional[torch.tensor]) \ + -> Tuple[Union[float, torch.tensor], ...]: + """ Inferencing using the interpreter + :param img_arr: float32 [0,1] numpy array with normalized image + data + :param other_arr: tensor array of additional data to be used in the + pilot, like IMU array for the IMU model or a + state vector in the Behavioural model + :return: tuple of (angle, throttle) + """ + out = self.interpreter.predict(img_arr, other_arr) + return self.interpreter_to_output(out) + + def inference_from_dict(self, input_dict: Dict[str, np.ndarray]) \ + -> Tuple[Union[float, np.ndarray], ...]: + """ Inferencing using the interpreter + :param input_dict: input dictionary of str and np.ndarray + :return: typically tuple of (angle, throttle) + """ + output = self.interpreter.predict_from_dict(input_dict) + return self.interpreter_to_output(output) + @abstractmethod + def interpreter_to_output( + self, + interpreter_out: Sequence[Union[float, np.ndarray]]) \ + -> Tuple[Union[float, np.ndarray], ...]: + """ Virtual method to be implemented by child classes for conversion + :param interpreter_out: input data + :return: output values, possibly tuple of np.ndarray + """ + pass - path = os.path.dirname(model_path) - fname = os.path.basename(model_path) - self.learn = load_learner(path=path, file=fname) + def train(self, + model_path: str, + train_data: TorchTubDataset, + train_steps: int, + batch_size: int, + validation_data: TorchTubDataset, + validation_steps: int, + epochs: int, + verbose: int = 1, + min_delta: float = .0005, + patience: int = 5, + show_plot: bool = False): + """ + trains the model + """ + assert isinstance(self.interpreter, FastAIInterpreter) + model = self.interpreter.model - def run(self, img): - t = pil2tensor(img, dtype=np.float32) # converts to tensor - im = Image(t) # Convert to fastAi Image - this class has "apply_tfms" + dataLoader = DataLoaders.from_dsets(train_data, validation_data, bs=batch_size, shuffle=False) + if torch.cuda.is_available(): + dataLoader.cuda() + + #dataLoaderTest = self.dataBlock.dataloaders.test_dl(validation_data, with_labels=True) + #print(dataLoader.train[0]) + + callbacks = [ + EarlyStoppingCallback(monitor='valid_loss', + patience=patience, + min_delta=min_delta), + SaveModelCallback(monitor='valid_loss', + every_epoch=False + ) + ] + + self.learner = Learner(dataLoader, model, loss_func=self.loss, path=Path(model_path).parent) + + logger.info(self.learner.summary()) + logger.info(self.learner.loss_func) + + lr_result = self.learner.lr_find() + suggestedLr = float(lr_result[0]) + + logger.info(f"Suggested Learning Rate {suggestedLr}") + + self.learner.fit_one_cycle(epochs, suggestedLr, cbs=callbacks) + + torch.save(self.learner.model, model_path) + + if show_plot: + self.learner.recorder.plot_loss() + plt.savefig(Path(model_path).with_suffix('.png')) + + history = { "loss" : list(map((lambda x: x.item()), self.learner.recorder.losses)) } + return history + + def __str__(self) -> str: + """ For printing model initialisation """ + return type(self).__name__ - pred = self.learn.predict(im) - steering = float(pred[0].data[0]) - throttle = float(pred[0].data[1]) +class FastAILinear(FastAiPilot): + """ + The KerasLinear pilot uses one neuron to output a continuous value via + the Keras Dense layer with linear activation. One each for steering and + throttle. The output is not bounded. + """ + def __init__(self, + interpreter: Interpreter = FastAIInterpreter(), + input_shape: Tuple[int, ...] = (120, 160, 3), + num_outputs: int = 2): + self.num_outputs = num_outputs + self.loss = MSELossFlat() + + super().__init__(interpreter, input_shape) + + def create_model(self): + return Linear() + + def compile(self): + self.optimizer = self.optimizer + self.loss = 'mse' + + def interpreter_to_output(self, interpreter_out): + interpreter_out = (interpreter_out * 2) - 1 + steering = interpreter_out[0] + throttle = interpreter_out[1] return steering, throttle - \ No newline at end of file + + def y_transform(self, record: Union[TubRecord, List[TubRecord]]) -> XY: + assert isinstance(record, TubRecord), 'TubRecord expected' + angle: float = record.underlying['user/angle'] + throttle: float = record.underlying['user/throttle'] + return angle, throttle + + def y_translate(self, y: XY) -> Dict[str, Union[float, List[float]]]: + assert isinstance(y, tuple), 'Expected tuple' + angle, throttle = y + return {'n_outputs0': angle, 'n_outputs1': throttle} + + def output_shapes(self): + # need to cut off None from [None, 120, 160, 3] tensor shape + img_shape = self.get_input_shapes()[0][1:] + + +class Linear(nn.Module): + def __init__(self): + super(Linear, self).__init__() + self.dropout = 0.1 + # init the layers + self.conv24 = nn.Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2)) + self.conv32 = nn.Conv2d(24, 32, kernel_size=(5, 5), stride=(2, 2)) + self.conv64_5 = nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2)) + self.conv64_3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) + self.fc1 = nn.Linear(6656, 100) + self.fc2 = nn.Linear(100, 50) + self.drop = nn.Dropout(self.dropout) + self.relu = nn.ReLU() + self.output1 = nn.Linear(50, 1) + self.output2 = nn.Linear(50, 1) + self.flatten = nn.Flatten() + + def forward(self, x): + x = self.relu(self.conv24(x)) + x = self.drop(x) + x = self.relu(self.conv32(x)) + x = self.drop(x) + x = self.relu(self.conv64_5(x)) + x = self.drop(x) + x = self.relu(self.conv64_3(x)) + x = self.drop(x) + x = self.relu(self.conv64_3(x)) + x = self.drop(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.drop(x) + x = self.fc2(x) + x1 = self.drop(x) + angle = self.output1(x1) + throttle = self.output2(x1) + return torch.cat((angle, throttle), 1) diff --git a/donkeycar/parts/interpreter.py b/donkeycar/parts/interpreter.py index 2a3ce4292..a031e903c 100755 --- a/donkeycar/parts/interpreter.py +++ b/donkeycar/parts/interpreter.py @@ -5,13 +5,13 @@ from typing import Union, Sequence, List import tensorflow as tf +import torch from tensorflow import keras from tensorflow.python.framework.convert_to_constants import \ convert_variables_to_constants_v2 as convert_var_to_const from tensorflow.python.saved_model import tag_constants, signature_constants - logger = logging.getLogger(__name__) @@ -105,6 +105,9 @@ def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \ def predict_from_dict(self, input_dict) -> Sequence[Union[float, np.ndarray]]: pass + def summary(self) -> str: + pass + def __str__(self) -> str: """ For printing interpreter """ return type(self).__name__ @@ -165,6 +168,65 @@ def load_weights(self, model_path: str, by_name: bool = True) -> \ assert self.model, 'Model not set' self.model.load_weights(model_path, by_name=by_name) + def summary(self) -> str: + return self.model.summary() + +class FastAIInterpreter(Interpreter): + + def __init__(self): + super().__init__() + self.model: None + from fastai import learner as fastai_learner + from fastai import optimizer as fastai_optimizer + + def set_model(self, pilot: 'FastAiPilot') -> None: + self.model = pilot.create_model() + + def set_optimizer(self, optimizer: 'fastai_optimizer') -> None: + self.model.optimizer = optimizer + + def get_input_shapes(self): + assert self.model, 'Model not set' + return [inp.shape for inp in self.model.inputs] + + def compile(self, **kwargs): + pass + + def invoke(self, inputs): + outputs = self.model(inputs) + # for functional models the output here is a list + if type(outputs) is list: + # as we invoke the interpreter with a batch size of one we remove + # the additional dimension here again + output = [output.numpy().squeeze(axis=0) for output in outputs] + return output + # for sequential models the output shape is (1, n) with n = output dim + else: + return outputs.detach().numpy().squeeze(axis=0) + + def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \ + -> Sequence[Union[float, np.ndarray]]: + + inputs = torch.unsqueeze(img_arr, 0) + if other_arr is not None: + #other_arr = np.expand_dims(other_arr, axis=0) + inputs = [img_arr, other_arr] + return self.invoke(inputs) + + def load(self, model_path: str) -> None: + logger.info(f'Loading model {model_path}') + if torch.cuda.is_available(): + logger.info("using cuda for torch inference") + self.model = torch.load(model_path) + else: + logger.info("cuda not available for torch inference") + self.model = torch.load(model_path, map_location=torch.device('cpu')) + + logger.info(self.model) + self.model.eval() + + def summary(self) -> str: + return self.model class TfLite(Interpreter): """ diff --git a/donkeycar/parts/keras.py b/donkeycar/parts/keras.py index 686bde617..a58ca76e6 100644 --- a/donkeycar/parts/keras.py +++ b/donkeycar/parts/keras.py @@ -216,7 +216,7 @@ def train(self, except Exception as ex: print(f"problems with loss graph: {ex}") - return history + return history.history def x_transform(self, record: Union[TubRecord, List[TubRecord]]) -> XY: """ Return x from record, default returns only image array""" diff --git a/donkeycar/parts/pytorch/torch_data.py b/donkeycar/parts/pytorch/torch_data.py index 09292de82..e02439282 100644 --- a/donkeycar/parts/pytorch/torch_data.py +++ b/donkeycar/parts/pytorch/torch_data.py @@ -10,7 +10,7 @@ import pytorch_lightning as pl -def get_default_transform(for_video=False, for_inference=False): +def get_default_transform(for_video=False, for_inference=False, resize=True): """ Creates a default transform to work with torchvision models @@ -32,13 +32,15 @@ def get_default_transform(for_video=False, for_inference=False): std = [0.22803, 0.22145, 0.216989] input_size = (112, 112) - transform = transforms.Compose([ - transforms.Resize(input_size), + transform_items = [ transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) - ]) + ] - return transform + if resize: + transform_items.insert(0, transforms.Resize(input_size)) + + return transforms.Compose(transform_items) class TorchTubDataset(IterableDataset): @@ -64,6 +66,7 @@ def __init__(self, config, records: List[TubRecord], transform=None): self.sequence = TubSequence(records) self.pipeline = self._create_pipeline() + self.len = len(records) def _create_pipeline(self): """ This can be overridden if more complicated pipelines are @@ -87,13 +90,14 @@ def x_transform(record: TubRecord): # Build pipeline using the transformations pipeline = self.sequence.build_pipeline(x_transform=x_transform, y_transform=y_transform) - return pipeline + def __len__(self): + return len(self.sequence) + def __iter__(self): return iter(self.pipeline) - class TorchTubDataModule(pl.LightningDataModule): def __init__(self, config: Any, tub_paths: List[str], transform=None): diff --git a/donkeycar/pipeline/training.py b/donkeycar/pipeline/training.py index 6aa064de7..b059c3d1c 100644 --- a/donkeycar/pipeline/training.py +++ b/donkeycar/pipeline/training.py @@ -81,7 +81,6 @@ def create_tf_data(self) -> tf.data.Dataset: output_shapes=self.model.output_shapes()) return dataset.repeat().batch(self.batch_size) - def get_model_train_details(database: PilotDatabase, model: str = None) \ -> Tuple[str, int]: if not model: @@ -108,7 +107,7 @@ def train(cfg: Config, tub_paths: str, model: str = None, if transfer: kl.load(transfer) if cfg.PRINT_MODEL_SUMMARY: - print(kl.interpreter.model.summary()) + print(kl.interpreter.summary()) tubs = tub_paths.split(',') all_tub_paths = [os.path.expanduser(tub) for tub in tubs] @@ -121,13 +120,23 @@ def train(cfg: Config, tub_paths: str, model: str = None, print(f'Records # Validation {len(validation_records)}') # We need augmentation in validation when using crop / trapeze - training_pipe = BatchSequence(kl, cfg, training_records, is_train=True) - validation_pipe = BatchSequence(kl, cfg, validation_records, is_train=False) - tune = tf.data.experimental.AUTOTUNE - dataset_train = training_pipe.create_tf_data().prefetch(tune) - dataset_validate = validation_pipe.create_tf_data().prefetch(tune) - train_size = len(training_pipe) - val_size = len(validation_pipe) + + if 'fastai_' in model_type: + from donkeycar.parts.pytorch.torch_data import TorchTubDataset, get_default_transform + transform = get_default_transform(resize=False) + dataset_train = TorchTubDataset(cfg, training_records, transform=transform) + dataset_validate = TorchTubDataset(cfg, validation_records, transform=transform) + train_size = len(training_records) + val_size = len(validation_records) + else: + training_pipe = BatchSequence(kl, cfg, training_records, is_train=True) + validation_pipe = BatchSequence(kl, cfg, validation_records, is_train=False) + tune = tf.data.experimental.AUTOTUNE + dataset_train = training_pipe.create_tf_data().prefetch(tune) + dataset_validate = validation_pipe.create_tf_data().prefetch(tune) + + train_size = len(training_pipe) + val_size = len(validation_pipe) assert val_size > 0, "Not enough validation data, decrease the batch " \ "size or add more data." @@ -162,7 +171,7 @@ def train(cfg: Config, tub_paths: str, model: str = None, 'Type': str(kl), 'Tubs': tub_paths, 'Time': time(), - 'History': history.history, + 'History': history, 'Transfer': os.path.basename(transfer) if transfer else None, 'Comment': comment, 'Config': str(cfg) @@ -170,4 +179,4 @@ def train(cfg: Config, tub_paths: str, model: str = None, database.add_entry(database_entry) database.write() - return history + return history \ No newline at end of file diff --git a/donkeycar/templates/complete.py b/donkeycar/templates/complete.py index 1c4b73607..418c8b6de 100644 --- a/donkeycar/templates/complete.py +++ b/donkeycar/templates/complete.py @@ -431,7 +431,7 @@ def load_model_json(kl, json_fnm): model_reload_cb = None if '.h5' in model_path or '.trt' in model_path or '.tflite' in \ - model_path or '.savedmodel' in model_path: + model_path or '.savedmodel' in model_path or '.pth': # load the whole model with weigths, etc load_model(kl, model_path) diff --git a/donkeycar/tests/test_train.py b/donkeycar/tests/test_train.py index de35d5fe3..c69c2b004 100644 --- a/donkeycar/tests/test_train.py +++ b/donkeycar/tests/test_train.py @@ -13,8 +13,8 @@ from donkeycar.utils import get_model_by_type, normalize_image, train_test_split Data = namedtuple('Data', - ['type', 'name', 'convergence', 'pretrained', 'preprocess'], - defaults=(None, ) * 5) + ['type', 'name', 'convergence', 'pretrained', 'preprocess', 'tf_lite', 'tensor_rt'], + defaults=(None, ) * 7) @pytest.fixture(scope='session') @@ -112,8 +112,9 @@ def car_dir(tmpdir_factory, base_config, imu_fields) -> str: d11 = Data(type='3d', name='3d1', convergence=0.6, pretrained=None) d12 = Data(type='linear', name='lin2', convergence=0.7, preprocess='aug') d13 = Data(type='linear', name='lin3', convergence=0.7, preprocess='trans') +d14 = Data(type='fastai_linear', name='linfastai1', convergence=0.6, pretrained=None, tf_lite=False, tensor_rt=False) -test_data = [d1, d2, d3, d6, d7, d8, d9, d10, d11, d12] +test_data = [d1, d2, d3, d6, d7, d8, d9, d10, d11, d12, d14] full_tub = ['imu', 'behavior', 'localizer'] @@ -140,8 +141,15 @@ def pilot_path(name): elif data.preprocess == 'trans': add_transformation_to_config(config) + if data.tf_lite is not None: + config.CREATE_TF_LITE = data.tf_lite + + if data.tensor_rt is not None: + config.CREATE_TENSOR_RT = data.tensor_rt + history = train(config, tub_dir, pilot_path(data.name), data.type) - loss = history.history['loss'] + loss = history['loss'] + # check loss is converging assert loss[-1] < loss[0] * data.convergence diff --git a/donkeycar/utils.py b/donkeycar/utils.py index 289fba75b..615d1f236 100644 --- a/donkeycar/utils.py +++ b/donkeycar/utils.py @@ -17,7 +17,7 @@ import time import signal import logging -from typing import List, Any, Tuple +from typing import List, Any, Tuple, Union from PIL import Image import numpy as np @@ -428,7 +428,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) -def get_model_by_type(model_type: str, cfg: 'Config') -> 'KerasPilot': +def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'FastAiPilot']: ''' given the string model_type and the configuration settings in cfg create a Keras model and return it. @@ -436,7 +436,10 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> 'KerasPilot': from donkeycar.parts.keras import KerasCategorical, KerasLinear, \ KerasInferred, KerasIMU, KerasMemory, KerasBehavioral, KerasLocalizer, \ KerasLSTM, Keras3D_CNN - from donkeycar.parts.interpreter import KerasInterpreter, TfLite, TensorRT + from donkeycar.parts.interpreter import KerasInterpreter, TfLite, TensorRT, \ + FastAIInterpreter + + from donkeycar.parts.fastai import FastAILinear if model_type is None: model_type = cfg.DEFAULT_MODEL_TYPE @@ -448,9 +451,15 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> 'KerasPilot': elif 'tensorrt_' in model_type: interpreter = TensorRT() used_model_type = model_type.replace('tensorrt_', '') + elif 'fastai_' in model_type: + interpreter = FastAIInterpreter() + used_model_type = model_type.replace('fastai_', '') + if used_model_type == "linear": + return FastAILinear(interpreter=interpreter, input_shape=input_shape) else: interpreter = KerasInterpreter() used_model_type = model_type + used_model_type = EqMemorizedString(used_model_type) if used_model_type == "linear": kl = KerasLinear(interpreter=interpreter, input_shape=input_shape) diff --git a/install/envs/mac.yml b/install/envs/mac.yml index 91919a515..b2dc3a9fb 100644 --- a/install/envs/mac.yml +++ b/install/envs/mac.yml @@ -4,6 +4,7 @@ channels: - defaults - conda-forge - pytorch + - fastai dependencies: - python=3.7 @@ -34,6 +35,7 @@ dependencies: - kivy=2.0.0 - plotly - pyyaml + - fastai - pip: - tensorflow==2.2.0 - git+https://github.com/autorope/keras-vis.git diff --git a/install/envs/ubuntu.yml b/install/envs/ubuntu.yml index 1b568eebc..d539a1e4c 100644 --- a/install/envs/ubuntu.yml +++ b/install/envs/ubuntu.yml @@ -4,6 +4,7 @@ channels: - defaults - conda-forge - pytorch + - fastai dependencies: - python=3.7 @@ -36,6 +37,7 @@ dependencies: - plotly - pyyaml - tensorflow=2.2.0 + - fastai - pip: - git+https://github.com/autorope/keras-vis.git - simple-pid diff --git a/install/envs/windows.yml b/install/envs/windows.yml index 11e55dea7..9d105c534 100644 --- a/install/envs/windows.yml +++ b/install/envs/windows.yml @@ -4,6 +4,7 @@ channels: - defaults - conda-forge - pytorch + - fastai dependencies: - python=3.7 @@ -35,6 +36,7 @@ dependencies: - plotly - pyyaml - psutil + - fastai - pip: - git+https://github.com/autorope/keras-vis.git - simple-pid diff --git a/setup.py b/setup.py index 447b5fa42..a1cb7ce5a 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,8 @@ def package_files(directory, strip_leading): 'torch': [ 'pytorch>=1.7.1', 'torchvision', - 'torchaudio' + 'torchaudio', + 'fastai' ], 'mm1': ['pyserial'] },