diff --git a/BUILD b/BUILD new file mode 100644 index 0000000..b7f8359 --- /dev/null +++ b/BUILD @@ -0,0 +1,63 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# DarwiNN Runtime Libaries. + +package(default_visibility = ["//visibility:public"]) + +# All Google Owned Code except : +# - certain files in port/default/ that are under Apache 2.0 license. +licenses(["notice"]) + +exports_files([ + "LICENSE", +]) + +# If --define darwinn_portable=1, compile without google3 deps. +config_setting( + name = "darwinn_portable", + values = { + "define": "darwinn_portable=1", + }, +) + +# If --define darwinn_portable=1 AND this is an otherwise non-portable config. +config_setting( + name = "darwinn_portable_with_non_portable_os", + flag_values = {"//tools/cpp:cc_target_os": "linux-google"}, + values = {"define": "darwinn_portable=1"}, +) + +# If --define darwinn_firmware=1, compile with minimal deps. +config_setting( + name = "darwinn_firmware", + values = { + "define": "darwinn_firmware=1", + }, +) + +config_setting( + name = "windows", + values = { + "cpu": "x64_windows", + }, +) + +config_setting( + name = "darwin", + values = { + "cpu": "darwin", + }, +) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..3414802 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +This project is not currently accepting contributions. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement (CLA). You (or your employer) retain the copyright to your +contribution; this simply gives us permission to use and redistribute your +contributions as part of the project. Head over to + to see your current agreements on file or +to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4e1e782 --- /dev/null +++ b/Makefile @@ -0,0 +1,115 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +SHELL := /bin/bash +MAKEFILE_DIR := $(realpath $(dir $(lastword $(MAKEFILE_LIST)))) +OUT_DIR := $(MAKEFILE_DIR)/out +OS := $(shell uname -s) + +ifeq ($(OS),Linux) +CPU ?= k8 +else ifeq ($(OS),Darwin) +CPU ?= darwin +else +$(error $(OS) is not supported) +endif + +ifeq ($(filter $(CPU),k8 armv6 armv7a aarch64 darwin),) +$(error CPU must be k8, armv7a, armv6, aarch64, or darwin) +endif + +COMPILATION_MODE ?= opt +ifeq ($(filter $(COMPILATION_MODE),opt dbg),) +$(error COMPILATION_MODE must be opt or dbg) +endif + +BAZEL_OUT_DIR := $(MAKEFILE_DIR)/bazel-out/$(CPU)-$(COMPILATION_MODE)/bin + +# Linux-specific parameters +BAZEL_BUILD_TARGET_Linux := //tflite/public:libedgetpu_direct_all.so +# --experimental_repo_remote_exec for remotable parameter used in +# --`repository_rule` from TF. +BAZEL_BUILD_FLAGS_Linux := --crosstool_top=@crosstool//:toolchains \ + --compiler=gcc \ + --linkopt=-l:libusb-1.0.so \ + --experimental_repo_remote_exec +BAZEL_BUILD_OUTPUT_FILE_Linux := libedgetpu.so.1.0 +BAZEL_BUILD_OUTPUT_SYMLINK_Linux := libedgetpu.so.1 + +ifeq ($(COMPILATION_MODE), opt) +BAZEL_BUILD_FLAGS_Linux += --linkopt=-Wl,--strip-all +endif +ifeq ($(CPU), armv6) +BAZEL_BUILD_FLAGS_Linux += --linkopt=-L/usr/lib/arm-linux-gnueabihf/ +endif + +# Darwin-specific parameters +BAZEL_BUILD_TARGET_Darwin := //tflite/public:libedgetpu_direct_usb.dylib +BAZEL_BUILD_FLAGS_Darwin := --linkopt=-L/opt/local/lib \ + --linkopt=-lusb-1.0 \ + --copt=-fvisibility=hidden +BAZEL_BUILD_OUTPUT_FILE_Darwin := libedgetpu.1.0.dylib +BAZEL_BUILD_OUTPUT_SYMLINK_Darwin := libedgetpu.1.dylib + +# Common parameters +BAZEL_BUILD_FLAGS := --sandbox_debug --subcommands \ + --compilation_mode=$(COMPILATION_MODE) \ + --define darwinn_portable=1 \ + --copt=-DSTRIP_LOG=1 \ + --copt=-DEDGETPU_EXTERNAL_RELEASE_RUNTIME \ + --copt=-fno-rtti \ + --copt=-fno-exceptions \ + --copt='-D__FILE__=""' \ + --cpu=$(CPU) +BAZEL_BUILD_FLAGS += $(BAZEL_BUILD_FLAGS_$(OS)) +BAZEL_BUILD_TARGET := $(BAZEL_BUILD_TARGET_$(OS)) +BAZEL_BUILD_OUTPUT_FILE := $(BAZEL_BUILD_OUTPUT_FILE_$(OS)) +BAZEL_BUILD_OUTPUT_SYMLINK := $(BAZEL_BUILD_OUTPUT_SYMLINK_$(OS)) + +define copy_out +mkdir -p $(OUT_DIR)/$(1)/$(CPU) && \ +cp -f $(BAZEL_OUT_DIR)/tflite/public/*$(suffix $(BAZEL_BUILD_TARGET)) \ + $(OUT_DIR)/$(1)/$(CPU)/$(BAZEL_BUILD_OUTPUT_FILE) && \ +ln -fs $(BAZEL_BUILD_OUTPUT_FILE) \ + $(OUT_DIR)/$(1)/$(CPU)/$(BAZEL_BUILD_OUTPUT_SYMLINK) +endef + +ifeq ($(OS),Darwin) +ifeq ($(COMPILATION_MODE),opt) +define strip_out +strip -x -S -o $(OUT_DIR)/$(1)/$(CPU)/$(BAZEL_BUILD_OUTPUT_FILE) \ + $(OUT_DIR)/$(1)/$(CPU)/$(BAZEL_BUILD_OUTPUT_FILE) +endef +endif +endif + +libedgetpu: libedgetpu-direct libedgetpu-throttled + +libedgetpu-direct: + bazel build $(BAZEL_BUILD_FLAGS) $(BAZEL_BUILD_TARGET) + $(call copy_out,direct) + $(call strip_out,direct) + +libedgetpu-throttled: + bazel build $(BAZEL_BUILD_FLAGS) --copt=-DTHROTTLE_EDGE_TPU $(BAZEL_BUILD_TARGET) + $(call copy_out,throttled) + $(call strip_out,throttled) + +clean: + rm -rf $(OUT_DIR) + +ifdef DOCKER_MK +DOCKER_WORKSPACE := $(MAKEFILE_DIR) +DOCKER_TAG_BASE=coral-libedgetpu +include $(DOCKER_MK) +endif diff --git a/README.md b/README.md index 38f3e77..16da0e5 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,33 @@ -# libedgetpu -Source code for the userspace level runtime driver for Coral.ai devices. 
+# Edge TPU runtime library (libedgetpu) + +This repo contains the source code for the userspace +level runtime driver for [Coral devices](https://coral.ai/products). +This software is distributed in the binary form at [coral.ai/software](https://coral.ai/software/). + +## Building + +At present only Bazel build system is supported, but it can be invoked from the Makefile. + +## Support + +If you have question, comments or requests concerning this library, please +reach out to coral-support@google.com. + +## License + +[Apache License 2.0](LICENSE) + +## Warning + +If you're using the Coral USB Accelerator, it may heat up during operation, depending +on the computation workloads and operating frequency. Touching the metal part of the USB +Accelerator after it has been operating for an extended period of time may lead to discomfort +and/or skin burns. As such, if you enable the Edge TPU runtime using the maximum operating +frequency, the USB Accelerator should be operated at an ambient temperature of 25°C or less. +Alternatively, if you enable the Edge TPU runtime using the reduced operating frequency, then +the device is intended to safely operate at an ambient temperature of 35°C or less. + +Google does not accept any responsibility for any loss or damage if the device +is operated outside of the recommended ambient temperature range. + +Note: This issue affects only USB-based Coral devices, and is irrelevant for PCIe devices. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..7060b32 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,75 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +workspace(name = "libedgetpu") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "io_bazel_rules_closure", + sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", + strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", + urls = [ + "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 + ], +) + +# Be consistent with tensorflow/WORKSPACE. +http_archive( + name = "bazel_skylib", + sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz"], +) # https://github.com/bazelbuild/bazel-skylib/releases + +# The TF commit # here must be in sync with that specified under Gob edgetpu +# repo WORKSPACE file. +# TODO: figure out a way to keep single source of truth of the +# TF commit # used. 
+TENSORFLOW_COMMIT = "f394a768719a55b5c351ed1ecab2ec6f16f99dd4"; +# Command to calculate: curl -OL | sha256sum | awk '{print $1}' +TENSORFLOW_SHA256 = "cb286abee7ee9cf5c8701d85fcc88f0fd59e72492ec4f254156de486e3e905c1" +http_archive( + name = "org_tensorflow", + sha256 = TENSORFLOW_SHA256, + strip_prefix = "tensorflow-" + TENSORFLOW_COMMIT, + urls = [ + "https://github.com/tensorflow/tensorflow/archive/" + TENSORFLOW_COMMIT + ".tar.gz", + ], +) + +load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") +tf_workspace(tf_repo_name = "org_tensorflow") + +http_archive( + name = "coral_crosstool", + sha256 = "cb31b1417ccdcf7dd9fca5ec63e1571672372c30427730255997a547569d2feb", + strip_prefix = "crosstool-9e00d5be43bf001f883b5700f5d04882fea00229", + urls = [ + "https://github.com/google-coral/crosstool/archive/9e00d5be43bf001f883b5700f5d04882fea00229.tar.gz", + ], +) +load("@coral_crosstool//:configure.bzl", "cc_crosstool") +cc_crosstool(name = "crosstool") +new_local_repository( + name = "libusb", + path = "/usr/include/", + build_file_content = """ +cc_library( + name = "headers", + includes = ["."], + hdrs = ["libusb-1.0/libusb.h"], + visibility = ["//visibility:public"], +) +""" +) diff --git a/api/BUILD b/api/BUILD new file mode 100644 index 0000000..9f06da5 --- /dev/null +++ b/api/BUILD @@ -0,0 +1,179 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Description: +# Darwinn API headers +load( + "@flatbuffers//:build_defs.bzl", + "flatbuffer_cc_library", +) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "chip", + hdrs = ["chip.h"], +) + +cc_library( + name = "tensor_util", + srcs = ["tensor_util.cc"], + hdrs = ["tensor_util.h"], + deps = [ + "//executable:executable_fbs", + "//port", + "//port:string_util", + ], +) + +cc_library( + name = "layer_information", + srcs = ["layer_information.cc"], + hdrs = ["layer_information.h"], + deps = [ + ":buffer", + ":tensor_util", + "//executable:executable_fbs", + "//port", + ], +) + +cc_library( + name = "request", + hdrs = ["request.h"], + deps = [ + ":buffer", + "//port", + ], +) + +cc_library( + name = "driver", + hdrs = ["driver.h"], + deps = [ + ":buffer", + ":package_reference", + ":request", + ":timing", + "//api:telemeter_interface", + "//port", + ], +) + +cc_library( + name = "driver_options_helper", + srcs = ["driver_options_helper.cc"], + hdrs = ["driver_options_helper.h"], + deps = [ + ":driver", + ":driver_options_fbs", + ], +) + +cc_library( + name = "allocated_buffer", + srcs = ["allocated_buffer.cc"], + hdrs = ["allocated_buffer.h"], + deps = ["//port"], +) + +cc_library( + name = "dram_buffer", + hdrs = ["dram_buffer.h"], + deps = ["//port"], +) + +cc_library( + name = "buffer", + srcs = ["buffer.cc"], + hdrs = ["buffer.h"], + deps = [ + ":allocated_buffer", + ":dram_buffer", + "//port", + ], +) + +cc_library( + name = "driver_factory", + srcs = ["driver_factory.cc"], + hdrs = ["driver_factory.h"], + deps = [ + ":chip", + ":driver", + ":driver_options_fbs", + "//port", + ], +) + +cc_library( + name = "package_reference", + hdrs = ["package_reference.h"], + deps = [ + ":execution_context_interface", + ":layer_information", + "//executable:executable_fbs", + "//port", + ], +) + +cc_library( + name = "runtime_version", + hdrs = ["runtime_version.h"], +) + +flatbuffer_cc_library( + name = "driver_options_fbs", + srcs = ["driver_options.fbs"], + flatc_args = [""], +) + +cc_library( + name = "timing", + hdrs = ["timing.h"], + deps = [ + "//port", + ], +) + +cc_library( + name = "watchdog", + srcs = ["watchdog.cc"], + hdrs = ["watchdog.h"], + deps = [ + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:timer", + ], +) + +cc_library( + name = "telemeter_interface", + hdrs = [ + "telemeter_interface.h", + ], + deps = [ + ":execution_context_interface", + ], +) + +cc_library( + name = "execution_context_interface", + hdrs = [ + "execution_context_interface.h", + ], +) diff --git a/api/allocated_buffer.cc b/api/allocated_buffer.cc new file mode 100644 index 0000000..4a00d31 --- /dev/null +++ b/api/allocated_buffer.cc @@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "api/allocated_buffer.h" + +#include "port/logging.h" + +namespace platforms { +namespace darwinn { + +AllocatedBuffer::AllocatedBuffer(unsigned char* ptr, size_t size_bytes, + FreeCallback free_callback) + : ptr_(ptr), + size_bytes_(size_bytes), + free_callback_(std::move(free_callback)) { + CHECK(ptr != nullptr); +} + +AllocatedBuffer::~AllocatedBuffer() { free_callback_(ptr_); } + +} // namespace darwinn +} // namespace platforms diff --git a/api/allocated_buffer.h b/api/allocated_buffer.h new file mode 100644 index 0000000..97740f0 --- /dev/null +++ b/api/allocated_buffer.h @@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_ALLOCATED_BUFFER_H_ +#define DARWINN_API_ALLOCATED_BUFFER_H_ + +#include + +namespace platforms { +namespace darwinn { + +// A type for buffer that holds (owns) allocated host memory. This class takes +// ownership of the buffer pointers passed into it, freeing them using the given +// Free function when destroyed. +class AllocatedBuffer { + public: + // A type for the callback executed to free the buffer. + using FreeCallback = std::function; + + AllocatedBuffer(unsigned char* ptr, size_t size_bytes, + FreeCallback free_callback); + + ~AllocatedBuffer(); + + // Not copyable or movable + AllocatedBuffer(const AllocatedBuffer&) = delete; + AllocatedBuffer& operator=(const AllocatedBuffer&) = delete; + + // Returns const buffer pointer. + const unsigned char* ptr() const { return ptr_; } + + // Returns buffer pointer. + unsigned char* ptr() { return ptr_; } + + // Size of this buffer in bytes. + size_t size_bytes() const { return size_bytes_; } + + private: + // Points to allocated buffer. + unsigned char* ptr_; + + // Size of the buffer. + size_t size_bytes_; + + // Callback executed to free the buffer. + FreeCallback free_callback_; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_ALLOCATED_BUFFER_H_ diff --git a/api/buffer.cc b/api/buffer.cc new file mode 100644 index 0000000..0e9c01f --- /dev/null +++ b/api/buffer.cc @@ -0,0 +1,173 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "api/buffer.h" + +#include + +#include "api/allocated_buffer.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { + +Buffer::Buffer(unsigned char* buffer, size_t size_bytes) + : type_(Type::kWrapped), size_bytes_(size_bytes), ptr_(buffer) {} + +Buffer::Buffer(const unsigned char* buffer, size_t size_bytes) + : Buffer(const_cast(buffer), size_bytes) {} + +Buffer::Buffer(void* buffer, size_t size_bytes) + : Buffer(reinterpret_cast(buffer), size_bytes) {} + +Buffer::Buffer(const void* buffer, size_t size_bytes) + : Buffer(const_cast(buffer), size_bytes) {} + +Buffer::Buffer(int fd, size_t size_bytes, bool on_device_dram) + : type_(on_device_dram ? Type::kDramWrapped : Type::kFileDescriptor), + size_bytes_(size_bytes), + file_descriptor_(fd) {} + +Buffer::Buffer(std::shared_ptr allocated_buffer) + : type_(Type::kAllocated), + size_bytes_(allocated_buffer->size_bytes()), + ptr_(allocated_buffer->ptr()), + allocated_buffer_(std::move(allocated_buffer)) {} + +Buffer::Buffer(std::shared_ptr dram_buffer) + : type_(Type::kDram), + size_bytes_(dram_buffer->size_bytes()), + file_descriptor_(dram_buffer->fd()), + dram_buffer_(std::move(dram_buffer)) {} + +bool Buffer::operator==(const Buffer& rhs) const { + return type_ == rhs.type_ && size_bytes_ == rhs.size_bytes_ && + ptr_ == rhs.ptr_ && allocated_buffer_ == rhs.allocated_buffer_; +} + +bool Buffer::operator!=(const Buffer& rhs) const { return !(*this == rhs); } + +Buffer::Buffer(Buffer&& other) + : type_(other.type_), + size_bytes_(other.size_bytes_), + ptr_(other.ptr_), + allocated_buffer_(std::move(other.allocated_buffer_)), + file_descriptor_(other.file_descriptor_), + dram_buffer_(std::move(other.dram_buffer_)) { + // Explicitly clear out other. + other.type_ = Type::kInvalid; + other.ptr_ = 0; + other.size_bytes_ = 0; + other.file_descriptor_ = -1; + // other.allocated_buffer handled in move above. +} + +Buffer& Buffer::operator=(Buffer&& other) { + if (this != &other) { + type_ = other.type_; + size_bytes_ = other.size_bytes_; + ptr_ = other.ptr_; + file_descriptor_ = other.file_descriptor_; + allocated_buffer_ = std::move(other.allocated_buffer_); + dram_buffer_ = std::move(other.dram_buffer_); + + // Explicitly clear out other. + other.type_ = Type::kInvalid; + other.ptr_ = 0; + other.file_descriptor_ = 0; + other.size_bytes_ = 0; + // other.allocated_buffer handled in move above. + } + return *this; +} + +Buffer Buffer::Slice(size_t offset, size_t length) const { + CHECK_LE(offset + length, size_bytes_); + CHECK(!FileDescriptorBacked() || offset == 0); + + Buffer ret = *this; + ret.ptr_ += offset; + ret.size_bytes_ = length; + return ret; +} + +const unsigned char* Buffer::ptr() const { + // FD and DRAM type Buffers need to be mapped before use. + if (type_ == Type::kFileDescriptor || + type_ == Type::kDram || + type_ == Type::kDramWrapped) { + LOG(FATAL) << "Called ptr() on buffer type " << type_; + } + + return ptr_; +} + +unsigned char* Buffer::ptr() { + // FD and DRAM type Buffers need to be mapped before use. 
+ if (type_ == Type::kFileDescriptor || + type_ == Type::kDram || + type_ == Type::kDramWrapped) { + LOG(FATAL) << "Called ptr() on buffer type " << type_; + } + + return ptr_; +} + +int Buffer::fd() const { + // Only valid with Type == kFileDescriptor, kDram or kDramWrapped + if (type_ != Type::kFileDescriptor && + type_ != Type::kDram && + type_ != Type::kDramWrapped) { + LOG(FATAL) << "Called fd() on buffer type " << type_; + } + + return file_descriptor_; +} + +util::StatusOr> Buffer::GetDramBuffer() { + if (type_ != Type::kDram) { + return util::FailedPreconditionError( + StringPrintf("Called GetDramBuffer on a buffer of type %d.", type_)); + } + return dram_buffer_; +} + +std::string Buffer::ToString() const { + if (FileDescriptorBacked()) { + return StringPrintf("Buffer(fd=%d)", file_descriptor_); + } else { + return StringPrintf("Buffer(ptr=%p)", ptr_); + } +} + +std::ostream& operator<<(std::ostream& stream, const Buffer::Type& type) { + switch (type) { + case Buffer::Type::kInvalid: + return (stream << "kInvalid"); + case Buffer::Type::kWrapped: + return (stream << "kWrapped"); + case Buffer::Type::kAllocated: + return (stream << "kAllocated"); + case Buffer::Type::kFileDescriptor: + return (stream << "kFileDescriptor"); + case Buffer::Type::kDram: + return (stream << "kDram"); + case Buffer::Type::kDramWrapped: + return (stream << "kDramWrapped"); + } +} +} // namespace darwinn +} // namespace platforms diff --git a/api/buffer.h b/api/buffer.h new file mode 100644 index 0000000..2a9ee29 --- /dev/null +++ b/api/buffer.h @@ -0,0 +1,183 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_BUFFER_H_ +#define DARWINN_API_BUFFER_H_ + +#include +#include +#include +#include +#include + +#include "api/allocated_buffer.h" +#include "api/dram_buffer.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { + +// Abstracts a buffer. Movable and copyable. +// TODO: Consider adding two different variants of this class +// for indicating Const and Mutable variants (like ArraySlice). For now, +// const Buffer, requires that contents of underlying buffer is const. +class Buffer { + public: + // Convenience structure for keeping track of named array of Buffers. + using NamedMap = std::unordered_map>; + + // Default constructor. Defaults to an invalid non-existent buffer. + Buffer() = default; + + // Constructors for wrapping an existing host buffer. + Buffer(void* buffer, size_t size_bytes); + Buffer(const void* buffer, size_t size_bytes); + + // Constructors for wrapping an existing host buffer, and optionally hints + // the runtime to cache it in on-chip memory outside the DarwiNN core. + Buffer(unsigned char* buffer, size_t size_bytes); + Buffer(const unsigned char* buffer, size_t size_bytes); + + // Constructors for wrapping an allocated buffer. + explicit Buffer(std::shared_ptr allocated_buffer); + + // Constructor for wrapping a file descriptor for existing memory. 
+ // on_device_dram: =true, the allocated memory is on DRAM, + // =false, the allocated memory is mmap-able shared memory. + Buffer(int fd, size_t size_bytes, bool on_device_dram = false); + + // Constructors for wrapping an on-chip DRAM buffer. + explicit Buffer(std::shared_ptr dram_buffer); + + // This type is copyable, with default implementations. + Buffer(const Buffer&) = default; + Buffer& operator=(const Buffer&) = default; + + // This type is movable. + Buffer(Buffer&& other); + Buffer& operator=(Buffer&& other); + + // Destructors. + ~Buffer() = default; + + // Get a slice of this buffer. Note that this does not resize the underlying + // storage, and the original buffer is still valid. The slice will be of the + // same type as this buffer. In particular, that means there will be an + // additional shared_ptr reference to the backing memory for allocated + // buffers. + // TODO: File descriptor-based buffers cannot be sliced unless + // the offset is 0. + Buffer Slice(size_t offset, size_t length) const; + + // Size of this buffer in bytes. + size_t size_bytes() const { return size_bytes_; } + + // Returns true if buffer is valid. + bool IsValid() const { return type_ != Type::kInvalid; } + + // Returns buffer pointer. + const unsigned char* ptr() const; + + // Returns buffer pointer. + unsigned char* ptr(); + + // Returns true if the buffer is backed by some host memory, may or may not be + // owned by this Buffer. + bool IsPtrType() const { + return type_ == Type::kWrapped || type_ == Type::kAllocated; + } + + // Returns file descriptor. + int fd() const; + + // Returns true if the buffer is backed by a file descriptor. + bool FileDescriptorBacked() const { + return type_ == Type::kFileDescriptor || + type_ == Type::kDram || + type_ == Type::kDramWrapped; + } + + // Returns true if this buffer is backed by a DramBuffer. + bool IsDramType() const { + return type_ == Type::kDram || type_ == Type::kDramWrapped; + } + + // Returns true if the buffer is managed by the runtime, + // i.e., the buffer does not wrap existing memory allocated + // outside the runtime. + bool IsManagedType() const { + return type_ == Type::kAllocated || type_ == Type::kDram; + } + + // Returns the underlying DRAM Buffer if this buffer is wrapping one managed + // by the runtime. + util::StatusOr> GetDramBuffer(); + + // Returns a string representation of the buffer for logging/debugging. + std::string ToString() const; + + // Equality operators. + bool operator==(const Buffer& rhs) const; + bool operator!=(const Buffer& rhs) const; + + private: + // Type for the buffer. + enum class Type { + // Invalid. + kInvalid = 0, + + // Wraps an existing host process addressable buffer. + kWrapped = 1, + + // Wraps an allocated host process addressable buffer. + kAllocated = 2, + + // Wraps an mmap-able file descriptor, possibly from ION. + kFileDescriptor = 3, + + // Wraps a buffer allocated from on-chip DRAM and managed by the runtime. + kDram = 4, + + // Wraps an existing, i.e., externally allocated, on-chip DRAM + // allocated buffer not managed by the runtime. + kDramWrapped = 5, + }; + + // To allow Buffer::Type being tested in CHECK() et al. + friend std::ostream& operator<< (std::ostream& stream, const Type& type); + + // Type for the buffer. + Type type_{Type::kInvalid}; + + // Size of the buffer. + size_t size_bytes_{0}; + + // Points to host buffer. Valid when type is kWrapped / kAllocated. + unsigned char* ptr_{nullptr}; + + // Points to allocated buffer. Valid when type is kAllocated. 
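To make the wrapping and Slice() semantics above concrete, here is a short, hypothetical caller-side sketch (not part of this header); `WrapAndSlice` and the vector of bytes stand in for any caller-owned memory.

```cpp
#include <vector>

#include "api/buffer.h"
#include "port/logging.h"

void WrapAndSlice() {
  std::vector<unsigned char> host(1024);

  // Wrap caller-owned memory; the Buffer does not take ownership.
  platforms::darwinn::Buffer whole(host.data(), host.size());

  // A slice aliases the same memory; the original Buffer stays valid.
  platforms::darwinn::Buffer half = whole.Slice(/*offset=*/0, /*length=*/512);

  // Both views are backed by a host pointer, so ptr() may be called.
  CHECK(whole.IsPtrType());
  CHECK_EQ(half.size_bytes(), 512u);
  CHECK_EQ(half.ptr(), whole.ptr());
}
```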
+ std::shared_ptr allocated_buffer_; + + // File descriptor. Valid when type is kFileDescriptor, kDram or kDramWrapped. + // Reset to -1 when moved. + int file_descriptor_{-1}; + + // Points to the DramBuffer. Valid when type is kDram. + std::shared_ptr dram_buffer_; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_BUFFER_H_ diff --git a/api/chip.h b/api/chip.h new file mode 100644 index 0000000..7cfdfa8 --- /dev/null +++ b/api/chip.h @@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_CHIP_H_ +#define DARWINN_API_CHIP_H_ + +#include + +#include + +namespace platforms { +namespace darwinn { +namespace api { + +// Target chip for runtime stack. +enum class Chip { + kBeagle, + kUnknown, +}; + +static const struct { + Chip chip; + const char* names[2]; +} kChipNames[] = { + {Chip::kBeagle, {"beagle", "beagle_fpga"}}, +}; + +// Returns correct Chip for given |chip_name|. +static inline Chip GetChipByName(const char* chip_name) { + for (auto& pair : kChipNames) + for (auto name : pair.names) + if (name != nullptr && strcmp(chip_name, name) == 0) return pair.chip; + return Chip::kUnknown; +} + +// Returns correct Chip for given |chip_name|. +static inline Chip GetChipByName(const std::string& chip_name) { + return GetChipByName(chip_name.c_str()); +} + +// Returns the name of the given |chip|. +static inline std::string GetChipName(Chip chip) { + for (auto& pair : kChipNames) + if (pair.chip == chip) return pair.names[0]; + return "unknown"; +} + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_CHIP_H_ diff --git a/api/dram_buffer.h b/api/dram_buffer.h new file mode 100644 index 0000000..54f0ea3 --- /dev/null +++ b/api/dram_buffer.h @@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_DRAM_BUFFER_H_ +#define DARWINN_API_DRAM_BUFFER_H_ + +#include "port/status.h" + +namespace platforms { +namespace darwinn { + +// Represents a buffer backed by on-chip DRAM. +class DramBuffer { + public: + DramBuffer() = default; + virtual ~DramBuffer() = default; + + // Returns the file descriptor to the DRAM buffer. + virtual int fd() const = 0; + + // Returns size of the buffer in bytes. + virtual size_t size_bytes() const = 0; + + // Copies size_bytes() bytes of data from source to the buffer. 
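As a quick illustration of the chip-name helpers in api/chip.h above, a hypothetical snippet (`LookUpChip` is not part of this file):

```cpp
#include "api/chip.h"
#include "port/logging.h"

void LookUpChip() {
  using platforms::darwinn::api::Chip;

  // Both spellings registered in kChipNames resolve to the same chip, and
  // unrecognized names fall back to Chip::kUnknown.
  CHECK(platforms::darwinn::api::GetChipByName("beagle") == Chip::kBeagle);
  CHECK(platforms::darwinn::api::GetChipByName("beagle_fpga") == Chip::kBeagle);
  CHECK(platforms::darwinn::api::GetChipByName("no-such-chip") ==
        Chip::kUnknown);

  // The canonical (first) name is returned when mapping back to a string.
  CHECK(platforms::darwinn::api::GetChipName(Chip::kBeagle) == "beagle");
}
```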
+ virtual util::Status ReadFrom(void* source) = 0; + + // Copies size_bytes() bytes ofr data from buffer to the destination address. + virtual util::Status WriteTo(void* destination) = 0; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_DRAM_BUFFER_H_ diff --git a/api/driver.h b/api/driver.h new file mode 100644 index 0000000..58aa671 --- /dev/null +++ b/api/driver.h @@ -0,0 +1,241 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_DRIVER_H_ +#define DARWINN_API_DRIVER_H_ + +#include +#include +#include +#include +#include + +#include "api/buffer.h" +#include "api/package_reference.h" +#include "api/request.h" +#include "api/telemeter_interface.h" +#include "api/timing.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// DarwiNN driver. Thread-safe, but not all functions can be called in +// callback context. +// +// Typical usage: +// Driver driver = driverFactory.get(); +// +// m1 = driver.RegisterExecutable() +// m2 = driver.RegisterExecutable() +// +// driver.Open(); +// r1 = driver.CreateRequest(m1, done_callback); +// r2 = driver.CreateRequest(m1, done_callback); +// driver.Submit(r1); +// driver.Submit(r2). +// driver.Close(); +//... +// After some time, application can try to re-open the driver again. +// driver.Open(); +// ... +// driver.Close(); +class Driver { + public: + // Callback for thermal warnings. Set with SetThermalWarningCallback(). + using ThermalWarningCallback = std::function; + + // Callback for fatal, unrecoverable failure. Set with + // SetFatalErrorCallback(). + using FatalErrorCallback = std::function; + + // Driver options. Opaque pointer to an options::Options FB object. + using Options = std::vector; + + // Current driver option version. Should match the version in + // driver_options.fbs. + static constexpr int kOptionsVersion = 1; + + // Specifies how driver should be closed. + enum class ClosingMode { + // Lets the active requests (the ones that have started DMA) finish and + // cancels pending requests. This may take a few milliseconds. + kGraceful = 0, + + // Cancels all active and pending requests. This is the fastest way we can + // close the driver without risk of crashing. + kAsap = 1, + }; + + // Specifies the way a model is preferred to be ran in terms of power/ + // performance trade-off. This can mapped to equivalent settings in higher + // level APIs (e.g. PreferenceCode in NNAPI). Please note that the enum + // integer values may be different from those in NNAPI or other APIs. The + // values here are defined in the order of priority when there are multiple + // models requesting different preferences (e.g. sustained speed takes + // priority over low power). For more information, please see: + // http://go/noronha-execution-preference + enum class ExecutionPreference { + // Run at the absolute maximum performance. 
+ kSingleFastAnswer = 0, + + // Ideal for cases in which we are trying to optimize for power. + kLowPower = 1, + + // Run at the maximum performance but in a way that does not require power / + // thermal throttling in the long run. + kSustainedSpeed = 2, + }; + + // Encapsulates different TPU (and related components) operational settings + // that can impact runtime behavior. + struct OperationalSettings { + // TPU clock-rate in hertz. + int64 tpu_frequency_hz; + + // Data transfer bandwidth between host DRAM and TPU in bytes per second. + int64 host_to_tpu_bps; + }; + + Driver() = default; + virtual ~Driver() = default; + + // This class is neither copyable nor movable. + Driver(const Driver&) = delete; + Driver& operator=(const Driver&) = delete; + + // Returns true if the driver is open state. + virtual bool IsOpen() const = 0; + + // Returns true if underlying hardware is in an error state. + virtual bool IsError() const = 0; + + // Registers a file containing pre-compiled DarwiNN executable and returns a + // reference to the registered executable. The reference can be used for + // constructing requests later on. + virtual util::StatusOr RegisterExecutableFile( + const std::string& executable_filename) = 0; + + // Registers a string with serialized contents of a pre-compiled DarwiNN + // executable and returns a reference to the registered executable. The + // reference can be used for constructing requests later on. + virtual util::StatusOr RegisterExecutableSerialized( + const std::string& executable_content) = 0; + virtual util::StatusOr RegisterExecutableSerialized( + const char* executable_content, size_t length) = 0; + + // Unregisters a previously registered model. + virtual util::Status UnregisterExecutable( + const PackageReference* executable_ref) = 0; + + // Opens and initializes the underlying hardware. If debug_mode is true, + // the hardware is setup for use with a debugger. If context_lost is true + // driver assumes all data on chip (e.g. on DRAM) a from previous open has + // been lost. + virtual util::Status Open(bool debug_mode = false, + bool context_lost = false) = 0; + + // Creates a request object initialized with the given ExecutableReference. + virtual util::StatusOr> CreateRequest( + const PackageReference* executable_ref) = 0; + + // Submits a request for asynchronous execution. On success, done_callback + // will eventually be executed with the request status. The caller is expected + // to exit the done_callback as soon as possible. It is acceptable to only + // call #Submit() in the context of this callback. + virtual util::Status Submit(std::shared_ptr request, + Request::Done done_callback) = 0; + + // Executes a request synchronously. Calling thread will block until execution + // is complete. + virtual util::Status Execute(std::shared_ptr request) = 0; + + // Executes a series of requests synchronously in the given order. Calling + // thread will block until execution is complete. + virtual util::Status Execute( + const std::vector>& request) = 0; + + // Attempts to cancel a request. This is best effort cancellation. As in, + // requests already submitted to the hardware will be allowed to complete. + // Other requests will be cancelled, and will invoke done_callback with + // cancelled error. + virtual util::Status Cancel(std::shared_ptr request) = 0; + + // Best effort cancellation of all submitted requests. + virtual util::Status CancelAllRequests() = 0; + + // Closes and shutdowns underlying hardware possibly, switching it off. 
+ // Pending requests are cancelled or completed and callbacks issued. Once + // closed, requests can no longer be submitted. + virtual util::Status Close(ClosingMode mode) = 0; + + // Buffer allocation alignment and granularity. + // Buffers allocated with this alignment may avoid additional copies within + // the driver. + virtual uint64_t allocation_alignment_bytes() const = 0; + + // Allocates size_bytes bytes and returns a Buffer for application use. The + // allocated memory is tied to the lifecycle of the Buffer object which in + // turn is tied to the life cycle of the driver instance. + virtual Buffer MakeBuffer(size_t size_bytes) const = 0; + + // Sets the callback for fatal, unrecoverable failure. When a fatal + // error is raised, the driver is pushed into an error state. All new + // submitted requests will fail. Application can generate a bug report and + // should close the driver, at which point all pending requests will fail and + // their callbacks executed. + virtual void SetFatalErrorCallback(FatalErrorCallback callback) = 0; + + // Sets the callback for thermal warnings. Application may be required to + // to reduce performance level and/or throttle new requests. + virtual void SetThermalWarningCallback(ThermalWarningCallback callback) = 0; + + // Enters/leaves real-time mode, if applicable. This is best effort as it + // relies on user provided timing information, and the fact that current + // generations of DarwiNN is not preemptable. + virtual util::Status SetRealtimeMode(bool on) = 0; + + // Sets expected arrival rates and max execution time (in milliseconds) for an + // package. Only used in real-time mode. + virtual util::Status SetExecutableTiming( + const api::PackageReference* executable, const api::Timing& timing) = 0; + + // Sets the provided execution preference for the provided package. Execution + // preferences are hints to the driver for how to adjust its settings in + // accordance with power/perf trade-off. Driver will try to keep all requested + // preferences satisfied erring on the side of performance. + virtual util::Status SetExecutionPreference( + const api::PackageReference* package, ExecutionPreference preference) = 0; + + // Sets the perferred telemeter interface. This interface is platform + // specific. By default, telemetry operations are NOPs. The + // telemeter_interface is not owned by api::Driver, so the telemeter's + // lifetime must remain valid as long as the driver object is valid. + virtual void SetTelemeterInterface( + api::TelemeterInterface* telemeter_interface) = 0; + + // Updates the operational settings in the driver. This method is to be called + // when any of these settings change (e.g. due to thermal throttling). + virtual void UpdateOperationalSettings( + const OperationalSettings& settings) = 0; + + // TODO: Add function for dumping bugreport. +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_DRIVER_H_ diff --git a/api/driver_factory.cc b/api/driver_factory.cc new file mode 100644 index 0000000..9a037f1 --- /dev/null +++ b/api/driver_factory.cc @@ -0,0 +1,90 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/driver_factory.h" + +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace api { + +namespace { + +Driver* NewDriver(const Device& device, const Driver::Options& options) { + // Build Driver. + auto factory = DriverFactory::GetOrCreate(); + auto backing_driver = factory->CreateDriver(device, options).ValueOrDie(); + return backing_driver.release(); +} + +} // namespace + +Driver* DriverFactory::CreateDriverAsSingleton(const Device& device, + const Driver::Options& options) { + static api::Driver* driver = NewDriver(device, options); + return driver; +} + +Device::Type GetTypeByName(const std::string& device_type) { + if (device_type == "PCI" || device_type == "pci") { + return Device::Type::PCI; + } else if (device_type == "USB" || device_type == "usb") { + return Device::Type::USB; + } else if (device_type == "PLATFORM" || device_type == "platform") { + return Device::Type::PLATFORM; + } else if (device_type == "REFERENCE" || device_type == "reference") { + return Device::Type::REFERENCE; + } else if (device_type == "SIMULATOR" || device_type == "simulator") { + return Device::Type::SIMULATOR; + } else if (device_type == "REMOTE_PCI" || device_type == "remote_pci") { + return Device::Type::REMOTE_PCI; + } else if (device_type == "REMOTE_USB" || device_type == "remote_usb") { + return Device::Type::REMOTE_USB; + } else if (device_type == "REMOTE_PLATFORM" || + device_type == "remote_platform") { + return Device::Type::REMOTE_PLATFORM; + } else { + LOG(FATAL) << "Unknown device type: " << device_type + << R"error(, which should be either "PCI", "USB", "PLATFORM", "REFERENCE", "REMOTE_PCI", "REMOTE_USB", "REMOTE_PLATFORM", or "SIMULATOR")error"; + unreachable(); // NOLINT + } +} + +std::string GetTypeName(Device::Type device_type) { + switch (device_type) { + case Device::Type::PCI: + return "pci"; + case Device::Type::USB: + return "usb"; + case Device::Type::PLATFORM: + return "platform"; + case Device::Type::REMOTE_PCI: + return "remote_pci"; + case Device::Type::REMOTE_USB: + return "remote_usb"; + case Device::Type::REMOTE_PLATFORM: + return "remote_platform"; + case Device::Type::REFERENCE: + return "reference"; + case Device::Type::SIMULATOR: + return "simulator"; + default: + return "unknown"; + } +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/api/driver_factory.h b/api/driver_factory.h new file mode 100644 index 0000000..19c4213 --- /dev/null +++ b/api/driver_factory.h @@ -0,0 +1,131 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
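Pulling the driver and factory pieces above together, a hypothetical end-to-end sketch of creating, opening, and using a driver. The device path, the "model.bin" file name, and the omission of input/output binding (api/request.h is not shown in this excerpt) are all illustrative assumptions; `DriverOptionsHelper` is defined later in this change.

```cpp
#include <memory>
#include <utility>

#include "api/chip.h"
#include "api/driver.h"
#include "api/driver_factory.h"
#include "api/driver_options_helper.h"
#include "port/logging.h"

namespace api = platforms::darwinn::api;

void RunModelOnce() {
  // Describe the device to attach to. The path below is a placeholder; real
  // paths normally come from DriverFactory::Enumerate().
  api::Device device;
  device.chip = api::Chip::kBeagle;
  device.type = api::GetTypeByName("pci");
  device.path = api::DriverFactory::kDefaultDevicePath;

  api::Driver* driver = api::DriverFactory::CreateDriverAsSingleton(
      device, api::DriverOptionsHelper::Defaults());

  // Register a compiled package, open the device, run one request, and close.
  // Binding of input/output buffers through the Request interface is omitted.
  auto* package_ref =
      driver->RegisterExecutableFile("model.bin").ValueOrDie();
  CHECK(driver->Open().ok());
  auto request = driver->CreateRequest(package_ref).ValueOrDie();
  CHECK(driver->Execute(std::move(request)).ok());
  CHECK(driver->Close(api::Driver::ClosingMode::kGraceful).ok());
}
```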
+ +#ifndef DARWINN_API_DRIVER_FACTORY_H_ +#define DARWINN_API_DRIVER_FACTORY_H_ + +#include +#include +#include +#include + +#include "api/chip.h" +#include "api/driver.h" +#include "api/driver_options_generated.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// A type for uniquely identifying a DarwiNN device. +struct Device { + // Device type. + enum class Type { + // PCI device. + // Path format: "/dev/" + // Example: /dev/apex_0 + PCI = 0, + + // USB device. + // Path format: "/sys/bus/usb/devices/-" + // Example: /sys/bus/usb/devices/3-5.1.2.1.2 + USB = 1, + + // Platform (integrated) device. + // Path format: "/dev/" + PLATFORM = 2, + + // Remote PCI device (for testing.) + // Path format: ":" + REMOTE_PCI = 10, + + // Remote USB device (for testing.) + // Path format: ":" + REMOTE_USB = 11, + + // Remote Platform device (for testing.) + // Path format: ":" + REMOTE_PLATFORM = 12, + + // Reference driver (for testing.) + REFERENCE = 30, + + // Simulator driver (for testing.) Path is ignored, + // Chip determines which simulator is instantiated. + SIMULATOR = 31, + }; + + // TODO: Replace map of strings with something better. + using Attributes = std::unordered_map; + + // Chip + Chip chip; + + // Device type. + Type type; + + // String that uniquely identifies the device. + // Set this to DriverFactory::kDefaultDevicePath for default device picked by + // the factory. + std::string path; + + // Device attributes discovered through enumeration. + // The exact set of possible key-value pairs is provider-specific. + Attributes attributes; +}; + +// Returns correct Device::Type for given |device_type|. +Device::Type GetTypeByName(const std::string& device_type); + +// Returns the name of the given |device_type|. +std::string GetTypeName(Device::Type device_type); + +// Enumerates devices and creates drivers for those devices. +class DriverFactory { + public: + static constexpr const char* kDefaultDevicePath = "default"; + + // Creates a singleton driver. + static Driver* CreateDriverAsSingleton(const Device& device, + const Driver::Options& options); + + // Creates or returns the singleton instance of the driver factory. + static DriverFactory* GetOrCreate(); + + virtual ~DriverFactory() = default; + + // Enumerates all available devices. + virtual std::vector Enumerate() = 0; + + // Creates a driver instance that interfaces to the specified device. + virtual util::StatusOr> CreateDriver( + const Device& device) = 0; + + // Creates a driver instance that interfaces to the specified device with + // custom options. + virtual util::StatusOr> CreateDriver( + const Device& device, const Driver::Options& options) = 0; + + protected: + // Constructor. + DriverFactory() = default; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_DRIVER_FACTORY_H_ diff --git a/api/driver_options.fbs b/api/driver_options.fbs new file mode 100644 index 0000000..c095a4d --- /dev/null +++ b/api/driver_options.fbs @@ -0,0 +1,91 @@ +// IDL file for DarwiNN Driver Options. + +namespace platforms.darwinn.api; + +enum PerformanceExpectation : int { + // Clock-rate setting affecting performance. + Low, + Medium, + High, + Max +} + +// USB-specific options +table DriverUsbOptions { + // Path to the DFU firmware. Empty string implies default values(s). + dfu_firmware:string; + + // If true, always performs DFU. Otherwise, only performs DFU if + // device is in DFU mode. 
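A complementary, hypothetical sketch (`CreateFirstAvailableDriver` is not part of this change) that uses enumeration instead of a hand-built Device; error handling is reduced to ValueOrDie(), mirroring driver_factory.cc above.

```cpp
#include <memory>
#include <vector>

#include "api/driver.h"
#include "api/driver_factory.h"
#include "port/logging.h"

namespace api = platforms::darwinn::api;

// Creates a driver for the first device the factory can see, using the
// CreateDriver overload without explicit options.
std::unique_ptr<api::Driver> CreateFirstAvailableDriver() {
  api::DriverFactory* factory = api::DriverFactory::GetOrCreate();
  std::vector<api::Device> devices = factory->Enumerate();
  CHECK(!devices.empty());

  auto driver = factory->CreateDriver(devices.front()).ValueOrDie();
  return driver;
}
```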
+ always_dfu:bool = true; + + // If true, driver would fail to open if the current connection is low, + // full, or high speed. If the connection speed is not observable from + // underlying provider, this option is ignored. + has_fail_if_slower_than_superspeed: bool = false; + fail_if_slower_than_superspeed: bool = false; + + // If true, bulk-in data is transmitted in largest chunks possible. + // By default, driver uses 1KB chunk size for USB3 and 256B for USB2. + // This is part of workaround for b/73181174 + has_force_largest_bulk_in_chunk_size: bool = false; + force_largest_bulk_in_chunk_size: bool = false; + + // If true, the fence between bulk-out and bulk-in would be lifted, allowing + // bulk-in to be issued before all bulk-out are finished. This feature could + // improve performance significantly on Android platform. + has_enable_overlapping_bulk_in_and_out: bool = false; + enable_overlapping_bulk_in_and_out: bool = true; + + // If true, multiple bulk-in requests would be issued instead of just one at + // any moment. enable_overlapping_bulk_in_and_out must also be true for + // this feature to be enabled. + has_enable_queued_bulk_in_requests: bool = false; + enable_queued_bulk_in_requests: bool = true; + + // Max number of buffers to queue when enable_queued_bulk_in_requests + // is true. + has_bulk_in_queue_capacity: bool = false; + bulk_in_queue_capacity: int = 32; +} + +table DriverOptions { + // Forwarded compatibility is possible only within the same version number + // by following Flatbuffer schema evolution guidelines. + // https://google.github.io/flatbuffers/md__schemas.html + version:int = 1; + usb:DriverUsbOptions; + + // Debug message verbosity inside the driver. + verbosity:int = 0; + performance_expectation:PerformanceExpectation = High; + + // The public key used for verifying executable packages in PEM format. + public_key:string; + + // The amount of time (in nanoseconds) that runtime waits to hear back from + // TPU (when it expects to) until it decides TPU is stuck and we need to + // perform a reset. A TPU reset results in all pending requests to be + // cancelled. 0 means Unlimited. + watchdog_timeout_ns:long = 0; + + // Operating frequency. Used for computing initial max execution time for + // drivers supporting real-time mode. + // -1 means no operating frequency supplied (for drivers without real-time + // mode support). + tpu_frequency_hz:int64 = -1; + + // The maximum amount of work (in terms of nanoseconds spent on TPU) that can + // be scheduled in the DMA scheduler at any given point in time. -1 means no + // maximum and all tasks get scheduled immediately. Exceptions are: + // 1. P0 requests. + // 2. When a single inference takes longer than this time and there is no + // other task scheduled (avoid starvation). + max_scheduled_work_ns:int64 = -1; + + // Data transfer bandwidth between host and TPU in bytes per second. -1 means + // assume infinite bandwidth. + host_to_tpu_bps:int64 = -1; +} + +root_type DriverOptions; diff --git a/api/driver_options_helper.cc b/api/driver_options_helper.cc new file mode 100644 index 0000000..6330634 --- /dev/null +++ b/api/driver_options_helper.cc @@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/driver_options_helper.h" + +#include "api/driver_options_generated.h" + +namespace platforms { +namespace darwinn { +namespace api { + +constexpr int64 kOperatingFrequency = 1000000LL; +constexpr int64 kHostTpuBps = 1000000000LL; + +// Returns driver options for default. +Driver::Options DriverOptionsHelper::Defaults() { + flatbuffers::FlatBufferBuilder builder; + auto options_offset = api::CreateDriverOptions( + builder, + /*version=*/1, + /*usb=*/0, + /*verbosity=*/0, + /*performance_expectation=*/api::PerformanceExpectation_High, + /*public_key=*/builder.CreateString(""), + /*watchdog_timeout_ns=*/0, + /*tpu_frequency_hz=*/kOperatingFrequency, + /*max_scheduled_work_ns=*/-1, + /*host_to_tpu_bps=*/kHostTpuBps); + builder.Finish(options_offset); + return api::Driver::Options(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); +} + +// Returns driver options for maximum performance. +Driver::Options DriverOptionsHelper::MaxPerformance() { + flatbuffers::FlatBufferBuilder builder; + auto options_offset = api::CreateDriverOptions( + builder, + /*version=*/1, + /*usb=*/0, + /*verbosity=*/0, + /*performance_expectation=*/api::PerformanceExpectation_Max, + /*public_key=*/builder.CreateString(""), + /*watchdog_timeout_ns=*/0, + /*tpu_frequency_hz=*/kOperatingFrequency, + /*max_scheduled_work_ns=*/-1, + /*host_to_tpu_bps=*/kHostTpuBps); + builder.Finish(options_offset); + return api::Driver::Options(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); +} + + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/api/driver_options_helper.h b/api/driver_options_helper.h new file mode 100644 index 0000000..11ad3b9 --- /dev/null +++ b/api/driver_options_helper.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_DRIVER_OPTIONS_HELPER_H_ +#define DARWINN_API_DRIVER_OPTIONS_HELPER_H_ + +#include "api/driver.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// A simpler wrapper around several static helper functions. +class DriverOptionsHelper { + public: + // Returns driver options for default. + static Driver::Options Defaults(); + + // Returns driver options for maximum performance. 
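Since Driver::Options is just a serialized FlatBuffer, callers can re-parse it. The sketch below assumes the standard flatc-generated accessor `GetDriverOptions()` from driver_options_generated.h (an assumption, as the generated header is not shown here); `InspectDefaultOptions` is illustrative only.

```cpp
#include "api/driver.h"
#include "api/driver_options_generated.h"
#include "api/driver_options_helper.h"
#include "port/logging.h"

namespace api = platforms::darwinn::api;

// Re-parses the serialized options to show they are an ordinary FlatBuffer.
void InspectDefaultOptions() {
  api::Driver::Options bytes = api::DriverOptionsHelper::Defaults();

  // GetDriverOptions() is the accessor flatc generates for root_type
  // DriverOptions (assumed here; see driver_options_generated.h).
  const api::DriverOptions* options = api::GetDriverOptions(bytes.data());
  CHECK_EQ(options->version(), 1);
  CHECK(options->performance_expectation() ==
        api::PerformanceExpectation_High);
}
```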
+ static Driver::Options MaxPerformance(); +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_DRIVER_OPTIONS_HELPER_H_ diff --git a/api/execution_context_interface.h b/api/execution_context_interface.h new file mode 100644 index 0000000..91298f3 --- /dev/null +++ b/api/execution_context_interface.h @@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_EXECUTION_CONTEXT_INTERFACE_H_ +#define DARWINN_API_EXECUTION_CONTEXT_INTERFACE_H_ + +namespace platforms { +namespace darwinn { +namespace api { + +// This class stores information related to on-device execution on the TPU. This +// empty base interface may be inherited to store any kind of execution related +// info. Info may include ML model information, process name, etc. This class is +// NOT thread-safe. +class ExecutionContextInterface { + public: + virtual ~ExecutionContextInterface() = default; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_EXECUTION_CONTEXT_INTERFACE_H_ diff --git a/api/layer_information.cc b/api/layer_information.cc new file mode 100644 index 0000000..f6d251c --- /dev/null +++ b/api/layer_information.cc @@ -0,0 +1,457 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/layer_information.h" + +#include +#include + +#include "api/buffer.h" +#include "api/tensor_util.h" +#include "executable/executable_generated.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace api { + +namespace { + +// Performs sanity check for the output shape information. Returns error if +// slice layout information is invalid. +util::Status SanityCheckShapeInformation(const OutputShapeInfo& shape_info, + int data_type_size) { + for (int i = 0; i < shape_info.slice_layout()->size(); ++i) { + // Each slice shape is stored in its own slice layout. Make sure the layout + // is valid. 
+ const auto& slice_layout = *shape_info.slice_layout()->Get(i); + if (!tensor_util::IsValidLayout(slice_layout)) { + return util::FailedPreconditionError( + StringPrintf("Invalid shape for slice %d: %s", i, + tensor_util::DumpLayout(slice_layout).c_str())); + } + const int slice_offset = shape_info.slice_offset()->Get(i); + if (slice_offset % data_type_size != 0) { + return util::FailedPreconditionError(StringPrintf( + "Slice offset [%d] is not aliged to data type size [%d].", + slice_offset, data_type_size)); + } + } + + return util::OkStatus(); +} + +// Copies elements in source shape to dest address. Destination layout is in +// dest_shape. +void CopyShape(const TensorShapeT& source_shape, + const TensorLayout& source_layout, + const unsigned char* source_address, + const TensorLayout& dest_layout, unsigned char* dest_address, + int bytes_per_element, int dimension) { + CHECK_LT(dimension, tensor_util::kNumDimensions); + CHECK_GE(dimension, 0); + + // Source shape can be in non-contiguous memory space if there are z-padding + // elements. + if (tensor_util::IsShapeInContiguousLayout(source_layout, source_shape) && + tensor_util::IsShapeInContiguousLayout(dest_layout, source_shape)) { + const int dest_offset = + tensor_util::GetFirstMemoryIndexForShape(dest_layout, source_shape) * + bytes_per_element; + const int source_offset = + tensor_util::GetFirstMemoryIndexForShape(source_layout, source_shape) * + bytes_per_element; + + // Do mem copy for the shape. + memcpy( + dest_address + dest_offset, source_address + source_offset, + tensor_util::GetNumElementsInShape(source_shape) * bytes_per_element); + } else { + const auto range = source_shape.dimension.at(dimension); + for (int i = range.start(); i <= range.end(); ++i) { + auto slice = source_shape; + slice.dimension.at(dimension) = {i, i}; + CopyShape(slice, source_layout, source_address, dest_layout, dest_address, + bytes_per_element, dimension + 1); + } + } +} + +} // namespace + +LayerInformation::LayerInformation(const Layer* layer) : layer_(layer) { + CHECK(layer != nullptr); +} + +int LayerInformation::DataTypeSize() const { + return TensorDataTypeSize(layer()->data_type()); +} + +bool LayerInformation::SignedDataType() const { + switch (layer()->data_type()) { + case DataType_SIGNED_FIXED_POINT8: + case DataType_SIGNED_FIXED_POINT16: + return true; + + case DataType_FIXED_POINT8: + case DataType_FIXED_POINT16: + // TODO: DataType_SIGNED_FIXED_POINT32 (previously + // DataType_FIXED_POINT32) is a signed number, see b/135944737. + // However, the function returns false, which looks like a bug. Please + // confirm it. + case DataType_SIGNED_FIXED_POINT32: + case DataType_BFLOAT: + case DataType_HALF: + case DataType_SINGLE: + return false; + } +} + +util::Status LayerInformation::TransformSignedDataType(Buffer buffer) const { + const auto data_type_size = DataTypeSize(); + if (buffer.size_bytes() < ActualSizeBytes()) { + return util::InvalidArgumentError(StringPrintf( + "Provided buffer size (%zu) is less than actual size_bytes (%d).", + buffer.size_bytes(), ActualSizeBytes())); + } + auto buffer_pointer = buffer.ptr(); + int buffer_index = 0; + + for (int y = 0; y < y_dim(); ++y) { + for (int x = 0; x < x_dim(); ++x) { + for (int z = 0; z < z_dim(); ++z) { + // XORing with 128 on the last byte of each entry will flip the MSB of + // each entry. Please note that bytes are stored little endian. 
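+        // For example, for FIXED_POINT8 this maps 0x00 (0) to 0x80 (-128),
+        // 0x80 (128) to 0x00 (0), and 0xFF (255) to 0x7F (127); for 16-bit
+        // types only the high-order byte (the last one in little-endian
+        // order) is flipped, e.g. 0x0000 <-> 0x8000.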
+ int msb_index = buffer_index + data_type_size - 1; + buffer_pointer[msb_index] = buffer_pointer[msb_index] ^ 128; + buffer_index += data_type_size; + } + } + } + + return util::OkStatus(); +} + +InputLayerInformation::InputLayerInformation(const Layer* layer) + : LayerInformation(layer) {} + +OutputLayerInformation::OutputLayerInformation(const Layer* layer) + : LayerInformation(layer), + output_layer_(layer->any_layer_as_OutputLayer()) { + CHECK(output_layer_ != nullptr); +} + +OutputLayerInformation::YBufferIndex OutputLayerInformation::GetYBufferIndex( + int y) const { + const auto& layout = output_layer_->layout(); + YBufferIndex output; + output.y_linearized_tile_id = + layout->y_coordinate_to_linear_tile_id_map()->Get(y); + output.local_y_coordinate = layout->y_coordinate_to_local_y_offset()->Get(y); + return output; +} + +int OutputLayerInformation::GetBufferIndex(const YBufferIndex& y_buffer_index, + int x, int z) const { + const auto& layout = output_layer_->layout(); + const int linear_tile_id = + y_buffer_index.y_linearized_tile_id + + layout->x_coordinate_to_linear_tile_id_map()->Get(x); + const int global_tile_byte_offset = + layout->linearized_tile_byte_offset()->Get(linear_tile_id); + + const int local_x_byte_offset = + layout->x_coordinate_to_local_byte_offset()->Get(x); + const int local_y_byte_offset = + y_buffer_index.local_y_coordinate * + layout->x_coordinate_to_local_y_row_size()->Get(x); + + return global_tile_byte_offset + local_y_byte_offset + local_x_byte_offset + + z; +} + +int OutputLayerInformation::GetBufferIndex(int y, int x, int z) const { + return GetBufferIndex(GetYBufferIndex(y), x, z); +} + +bool OutputLayerInformation::NeedsRelayout() const { + if (!output_layer_->shape_info()) { + return true; + } + // Relayout is not needed when output layout has only one shape with no + // padding between elements. + const auto& slice_layouts = *output_layer_->shape_info()->slice_layout(); + return slice_layouts.size() > 1 || + !tensor_util::IsNoPaddingLayout(*slice_layouts.Get(0)); +} + +// TODO Add unit tests for this method. +util::Status OutputLayerInformation::Relayout(unsigned char* dest, + const unsigned char* src) const { + // TODO: re-use the same buffer and avoid an unnecessary memcopy + // when relayout is not needed. + if (!NeedsRelayout()) { + memcpy(dest, src, + batch_dim() * y_dim() * x_dim() * z_dim() * DataTypeSize()); + return util::OkStatus(); + } + + if (output_layer_->shape_info()) { + // If output shape info exists in the executable, use the new re-layout + // function. Currently, this is only enabled for models with multiple + // batches. + return RelayoutWithShapeInformation(dest, src); + } + + const auto data_type_size = DataTypeSize(); + const int z_bytes = z_dim() * data_type_size; + + if (y_dim() == 1 && x_dim() == 1) { + // One dimensional output (only z-dimension). + if (src != dest) { + const int padded_size_bytes = PaddedSizeBytes(); + const int actual_size_bytes = ActualSizeBytes(); + const int executions = execution_count_per_inference(); + if (executions == 1 || padded_size_bytes == actual_size_bytes) { + memcpy(dest, src, z_bytes * executions); + } else { + // Remove padding values at the end of each execution. 
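+          // For example (illustrative values): with 2 executions,
+          // z_bytes = 6, actual_size_bytes = 12 and padded_size_bytes = 16,
+          // each execution carries 2 trailing padding bytes, so the loop
+          // below copies 6 bytes and then advances `src` by 8 bytes per
+          // execution.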
+ const int padded_size_per_execution = + (padded_size_bytes - actual_size_bytes) / executions; + for (int i = 0; i < executions; ++i) { + memcpy(dest, src, z_bytes); + dest += z_bytes; + src += z_bytes + padded_size_per_execution; + } + } + } + } else { + int z_bytes_padded; + if (x_dim() > 1) { + // If x-dim is > 1, padded-z-size can be deduced by looking at difference + // between offset of element y=0,x=0,z=0 and y=0,x=1,z=0. + z_bytes_padded = GetBufferIndex(0, 1, 0) - GetBufferIndex(0, 0, 0); + } else { + // Otherwise when x-dim is 1 (y-dim must be > 1 in that case), + // padded-z-size can be deduced by looking at difference between offset of + // element y=0,x=0,z=0 and y=1,x=0,z=0. + z_bytes_padded = GetBufferIndex(1, 0, 0) - GetBufferIndex(0, 0, 0); + } + z_bytes_padded *= data_type_size; + + const auto* layout = output_layer_->layout(); + int last_x = 0; + int last_x_tile = layout->x_coordinate_to_linear_tile_id_map()->Get(0); + std::vector active_tile_x_sizes; + for (int x = 1; x < x_dim(); ++x) { + int cur_x_tile = layout->x_coordinate_to_linear_tile_id_map()->Get(x); + if (cur_x_tile != last_x_tile) { + active_tile_x_sizes.push_back(x - last_x); + last_x_tile = cur_x_tile; + last_x = x; + } + } + active_tile_x_sizes.push_back(x_dim() - last_x); + +// When the num_z_bytes parameter is a compile-time constant, the conditions +// in the innermost loop will be replaced with a single optimized path, +// specialized for that value. +// Specialization is provided for num_z_bytes value of 1 and 3. +// We can also make this a helper function and still realize the benefits +// provided we have a guaranteed way of ensuring this function would be inlined +// so that the compiler optimizations based on compile-time-constants can kick +// in. +#define RELAYOUT_WITH_Z_BYTES_SPECIALIZATION( \ + num_z_bytes, num_z_bytes_padded) \ + do { \ + for (int y = 0; y < y_dim(); ++y) { \ + const auto y_buffer_index = GetYBufferIndex(y); \ + int tile_starting_x = 0; \ + for (int x_tile = 0; x_tile < active_tile_x_sizes.size(); ++x_tile) { \ + const unsigned char* source = \ + src + GetBufferIndex(y_buffer_index, tile_starting_x, /*z=*/0) * \ + data_type_size; \ + const int tile_x_size = active_tile_x_sizes[x_tile]; \ + for (int local_offset_x = 0; local_offset_x < tile_x_size; \ + ++local_offset_x) { \ + if ((num_z_bytes) == 1) { \ + *dest = *source; \ + } else if ((num_z_bytes) == 3) { \ + *(dest + 0) = *(source + 0); \ + *(dest + 1) = *(source + 1); \ + *(dest + 2) = *(source + 2); \ + } else { \ + memcpy(dest, source, (num_z_bytes)); \ + } \ + dest += (num_z_bytes); \ + source += (num_z_bytes_padded); \ + } \ + tile_starting_x += tile_x_size; \ + } \ + } \ + } while (0) + + if (z_bytes != z_bytes_padded) { + if (z_bytes == 1) { + // Specialization for z_bytes = 1 (grayscale image). + RELAYOUT_WITH_Z_BYTES_SPECIALIZATION(1, 4); + } else if (z_bytes == 3) { + // Specialization for z_bytes = 3 (RGB image). + RELAYOUT_WITH_Z_BYTES_SPECIALIZATION(3, 4); + } else { + // Default. + RELAYOUT_WITH_Z_BYTES_SPECIALIZATION(z_bytes, z_bytes_padded); + } + } else { + // TODO: test models impacted with this relayout method. + const int first_y_tile = + layout->y_coordinate_to_linear_tile_id_map()->Get(0); + const int last_y_tile = + layout->y_coordinate_to_linear_tile_id_map()->Get(y_dim() - 1); + + // If there's only one output shape from one tile and no Z padding, copy + // data directly. 
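+      // For example (illustrative): if every x column and every y row maps
+      // to the same tile, active_tile_x_sizes has a single entry and
+      // first_y_tile == last_y_tile, so the whole y * x * z block is copied
+      // with a single memcpy below instead of being re-laid-out row by row.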
+ const bool need_relayout = + active_tile_x_sizes.size() > 1 || first_y_tile != last_y_tile; + + if (need_relayout) { + // If there's no z padding, copy one xz block on one tile at a time. + for (int y = 0; y < y_dim(); ++y) { + const auto y_buffer_index = GetYBufferIndex(y); + int tile_starting_x = 0; + for (int x_tile = 0; x_tile < active_tile_x_sizes.size(); ++x_tile) { + const unsigned char* source = + src + GetBufferIndex(y_buffer_index, tile_starting_x, /*z=*/0) * + data_type_size; + const int tile_x_size = active_tile_x_sizes[x_tile]; + const int tile_x_z_bytes = z_bytes * tile_x_size; + memcpy(dest, source, tile_x_z_bytes); + dest += tile_x_z_bytes; + tile_starting_x += tile_x_size; + } + } + } else { + // TODO: avoid copy and assign in caller directly. + memcpy(dest, src, x_dim() * y_dim() * z_bytes); + } + } + +#undef RELAYOUT_WITH_Z_BYTES_SPECIALIZATION + + // TODO: If iteration count is more than 1, we need to make sure we + // advance 'src' and 'dest' correctly due to padding issue. We don't have + // test case now. + CHECK_EQ(execution_count_per_inference(), 1) + << "Verification is missing if execution count is greater than 1"; + } + + return util::OkStatus(); +} + +util::Status OutputLayerInformation::RelayoutWithShapeInformation( + unsigned char* dest, const unsigned char* src) const { + CHECK_EQ(execution_count_per_inference(), 1) + << "Multiple inference execution not supported in the new relayout " + "(b/129301507)."; + + const auto data_type_size = DataTypeSize(); + const auto& shape_info = *output_layer_->shape_info(); + + RETURN_IF_ERROR(SanityCheckShapeInformation(shape_info, data_type_size)); + + flatbuffers::FlatBufferBuilder builder; + const auto fb_layout = tensor_util::BuildPackedLayout(*layer()->shape()); + builder.Finish(darwinn::TensorLayout::Pack(builder, fb_layout.get())); + const TensorLayout& dest_layout = + *flatbuffers::GetRoot(builder.GetBufferPointer()); + unsigned char* dest_address = dest; + + const auto& slice_layouts = *shape_info.slice_layout(); + for (int i = 0; i < slice_layouts.size(); ++i) { + // Each slice is stored in a contiguous memory space. 
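+    // For example (illustrative, not from a real executable): a batch-8
+    // output produced by two tiles might arrive as slice 0 covering batches
+    // [0:3] at slice_offset 0 and slice 1 covering batches [4:7] at a later
+    // byte offset; each slice is then scattered into the packed destination
+    // layout by CopyShape below.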
+ const TensorLayout& source_layout = *slice_layouts.Get(i); + TensorShapeT source_shape; + source_layout.shape()->UnPackTo(&source_shape); + const unsigned char* source_address = + src + shape_info.slice_offset()->Get(i); + + CopyShape(source_shape, source_layout, source_address, dest_layout, + dest_address, data_type_size, tensor_util::kBatch); + } + + return util::OkStatus(); +} + +int OutputLayerInformation::GetBufferIndex( + const std::vector& element_position) const { + const auto* shape_info = output_layer_->shape_info(); + if (!shape_info) { + CHECK_EQ(element_position.size(), 4); + CHECK_EQ(element_position[tensor_util::kBatch], 0); + return GetBufferIndex(/*y=*/element_position[tensor_util::kY], + /*x=*/element_position[tensor_util::kX], + /*z=*/element_position[tensor_util::kZ]); + } + + const int data_type_size = DataTypeSize(); + const auto& slice_layouts = *shape_info->slice_layout(); + for (int i = 0; i < slice_layouts.size(); ++i) { + const TensorLayout& slice_layout = *slice_layouts.Get(i); + const TensorShape& slice_shape = *slice_layout.shape(); + if (tensor_util::IsElementInShape(slice_shape, element_position)) { + const int index = tensor_util::GetMemoryIndexFromPosition( + slice_layout, element_position); + const int slice_base_offset_in_bytes = shape_info->slice_offset()->Get(i); + CHECK_EQ(slice_base_offset_in_bytes % data_type_size, 0); + const int slice_base_offset_in_elements = + slice_base_offset_in_bytes / data_type_size; + return slice_base_offset_in_elements + index; + } + } + + std::string position_string; + for (int index : element_position) { + position_string += StringPrintf("[%d]", index); + } + + LOG(FATAL) << "Cannot find element in output: " << position_string; + return 0; +} + +int TensorDataTypeSize(DataType data_type) { + switch (data_type) { + case DataType_FIXED_POINT8: + case DataType_SIGNED_FIXED_POINT8: + return 1; + case DataType_FIXED_POINT16: + case DataType_SIGNED_FIXED_POINT16: + return 2; + case DataType_SIGNED_FIXED_POINT32: + return 4; + case DataType_BFLOAT: + return 2; + case DataType_HALF: + return 2; + case DataType_SINGLE: + return 4; + } +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/api/layer_information.h b/api/layer_information.h new file mode 100644 index 0000000..168fe1f --- /dev/null +++ b/api/layer_information.h @@ -0,0 +1,177 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_LAYER_INFORMATION_H_ +#define DARWINN_API_LAYER_INFORMATION_H_ + +#include + +#include "api/buffer.h" +#include "api/tensor_util.h" +#include "executable/executable_generated.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// Provides information on input and output layers. +class LayerInformation { + public: + virtual ~LayerInformation() = default; + + // Copyable. 
+ LayerInformation(const LayerInformation& rhs) = default; + LayerInformation& operator=(const LayerInformation& rhs) = default; + + // Returns layer name. + std::string name() const { return layer_->name()->str(); } + + // Layer dimensions. + int x_dim() const { return layer_->x_dim(); } + int y_dim() const { return layer_->y_dim(); } + int z_dim() const { return layer_->z_dim(); } + int batch_dim() const { + if (layer_->shape()) { + return tensor_util::GetDimensionLength(*layer_->shape(), 0); + } else { + return 1; + } + } + + // Returns the zero point value. + int zero_point() const { return layer_->numerics()->zero_point(); } + + // Returns the execution count per inference. + int execution_count_per_inference() const { + return layer_->execution_count_per_inference(); + } + + // Returns the dequantization factor. + float dequantization_factor() const { + return layer_->numerics()->dequantization_factor(); + } + + // Returns data type in this layer. + darwinn::DataType data_type() const { return layer_->data_type(); } + + // Returns the size of the data type in this layer in bytes. + int DataTypeSize() const; + + // Returns true if the data type is signed. + bool SignedDataType() const; + + // Returns the expected byte size of activations for given layer. This + // excludes padding. + int ActualSizeBytes() const { + const int num_elements = + (layer_->shape()) ? tensor_util::GetNumElementsInShape(*layer_->shape()) + : x_dim() * y_dim() * z_dim(); + return num_elements * DataTypeSize() * + layer_->execution_count_per_inference(); + } + + // Returns the expected size of activations for given layer including padding + // bytes. + int PaddedSizeBytes() const { + return SizeBytesPerIteration() * layer_->execution_count_per_inference(); + } + + int SizeBytesPerIteration() const { return layer_->size_bytes(); } + + // Returns true if activations of this input/output layer need to be cached on + // DRAM. + bool CacheOnDram() const { return layer_->cache_on_dram(); } + + // Converts unsigned values for a provided buffer of this layer to signed and + // vice versa. + util::Status TransformSignedDataType(Buffer buffer) const; + + protected: + explicit LayerInformation(const Layer* layer); + + const Layer* layer() const { return layer_; } + + private: + const Layer* layer_; +}; + +// Provides detailed information on input layers. +class InputLayerInformation : public LayerInformation { + public: + explicit InputLayerInformation(const Layer* layer); + ~InputLayerInformation() override = default; +}; + +// Provides detailed information on output layers. +class OutputLayerInformation : public LayerInformation { + public: + // Holds y-dependent values that are needed to calculate buffer index. + // Expected usage is as follows: + // + // for (int y = 0; y < y_dim(); ++y) { + // const YBufferIndex y_buffer_index = GetYBufferIndex(y); + // for (int x = 0; x < x_dim(); ++x) { + // const int src_offset = + // GetBufferIndex(y_buffer_index, x, /*z=*/0) * data_type_size; + // ... + // } + // } + struct YBufferIndex { + // Holds the linearized tile ID for a given y value. + int y_linearized_tile_id; + // Holds local offset within a data chunk returned by a given tile. + int local_y_coordinate; + }; + + explicit OutputLayerInformation(const Layer* layer); + ~OutputLayerInformation() override = default; + + // Returns an index value of output buffer for a given tensor coordinate. 
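+  //
+  // Example (illustrative): to locate a single element in the raw DarwiNN
+  // output buffer `src`, scale the returned index by the element size:
+  //   const int byte_offset = GetBufferIndex(y, x, z) * DataTypeSize();
+  //   const unsigned char* element = src + byte_offset;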
+ int GetBufferIndex(int y, int x, int z) const; + int GetBufferIndex(const std::vector& element_position) const; + + // Essentially does the same thing as functions above, but these two split + // functions allow a user to save some computation such as using the samve + // value from GetYBufferIndex across all x pixels. + YBufferIndex GetYBufferIndex(int y) const; + int GetBufferIndex(const YBufferIndex& y_buffer_index, int x, int z) const; + + // Relayout the source DarwiNN output buffer (TYXZ layout, T = Tile) into + // user output buffer (YXZ layout). + // + // TODO Move this method down to driver internal classes once all + // dependencies are removed. + util::Status Relayout(unsigned char* dest, const unsigned char* src) const; + + // Returns true if relayout is needed. + bool NeedsRelayout() const; + + private: + // Re-layouts the output activation stream from the tiles into a desired + // format in the host memory. + util::Status RelayoutWithShapeInformation(unsigned char* dest, + const unsigned char* src) const; + + const OutputLayer* output_layer_; +}; + +// Returns the byte size of a provided tensor data type. +int TensorDataTypeSize(DataType data_type); + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_LAYER_INFORMATION_H_ diff --git a/api/package_reference.h b/api/package_reference.h new file mode 100644 index 0000000..b249be0 --- /dev/null +++ b/api/package_reference.h @@ -0,0 +1,156 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_PACKAGE_REFERENCE_H_ +#define DARWINN_API_PACKAGE_REFERENCE_H_ + +#include +#include +#include +#include + +#include "api/execution_context_interface.h" +#include "api/layer_information.h" +#include "executable/executable_generated.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// Specifies the most recent package identifier for executable.fbs. +constexpr const char* kHeadPackageIdentifier = "DWN1"; + +// Type for a registered executable. +class PackageReference { + public: + virtual ~PackageReference() = default; + + // This class is neither copyable nor movable. + PackageReference(const PackageReference&) = delete; + PackageReference& operator=(const PackageReference&) = delete; + + // Verifies the digital signature of the backing executable package. + virtual util::Status VerifySignature() const = 0; + + // Returns the index of input layer with given name. + virtual util::StatusOr InputIndex(const std::string& name) const = 0; + + // Returns the index of output layer with given name. + virtual util::StatusOr OutputIndex(const std::string& name) const = 0; + + // Returns number of input layers. + virtual int NumInputLayers() const = 0; + + // Returns number of output layers. + virtual int NumOutputLayers() const = 0; + + // Returns list of input layer names. 
+ virtual const std::vector& InputLayerNames() const = 0; + + // Returns list of output layer names. + virtual const std::vector& OutputLayerNames() const = 0; + + // Returns information on given input layer. Returns nullptr if index is out + // of bounds. + virtual const InputLayerInformation* InputLayer(int index) const = 0; + + // Returns information on given output layer. Returns nullptr if index is out + // of bounds. + virtual const OutputLayerInformation* OutputLayer(int index) const = 0; + + // Returns information on given input layer. + virtual util::StatusOr InputLayer( + const std::string& layer_name) const = 0; + + // Returns information on given output layer. + virtual util::StatusOr OutputLayer( + const std::string& layer_name) const = 0; + + // Returns the expected byte size of activations for given input layer index. + virtual int InputLayerSizeBytes(int index) const = 0; + + // Returns the expected byte size of activations for given input layer index. + // This is post-padding, if any. + // TODO Remove this method. + virtual int InputLayerPaddedSizeBytes(int index) const = 0; + + // Returns the expected byte size of activations for given output layer index. + virtual int OutputLayerSizeBytes(int index) const = 0; + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + virtual int InputLayerSize(int index) const = 0; + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + virtual int OutputLayerSize(int index) const = 0; + + // Returns the expected size of activations for given input layer. + // Prefer index based APIs for performance. + virtual util::StatusOr InputLayerSizeBytes( + const std::string& name) const = 0; + + // Returns the expected size of activations for given input layer including + // padding bytes. + // Prefer index based APIs for performance. + // TODO Remove this method. + virtual util::StatusOr InputLayerPaddedSizeBytes( + const std::string& name) const = 0; + + // Returns the expected size of activations for given output layer. + // Prefer index based APIs for performance. + virtual util::StatusOr OutputLayerSizeBytes( + const std::string& name) const = 0; + + // Returns name for given input layer index. + virtual std::string InputLayerName(int index) const = 0; + + // Returns name for given output layer index. + virtual std::string OutputLayerName(int index) const = 0; + + // Returns batch size. + virtual int BatchSize() const = 0; + + // Sets the execution context (info related to execution). The execution + // context is later used for logging purposes. + virtual void SetExecutionContextInterface( + std::unique_ptr + execution_context_interface) = 0; + + // Sets the maximum amount of time this package can tolerate for an inference + // to finish. Setting this will make driver check if it can meet the latency + // target on each inference. If it cannot, it will immediately return a + // deadline exceeded error. Parameter-caching or anything extra that driver + // needs to run in order to complete an inference will be counted towards this + // target. If a batch request is submitted, the total time to complete the + // batch is counted (not a single batch element). + virtual util::Status SetLatencyTolerance(int64 max_latency_ms) = 0; + + // Returns a unique user-specified string identifies the model. It returns + // empty string if no identifier is set. This is available for limited cases + // only. 
+ virtual std::string ModelIdentifier() const = 0; + + protected: + PackageReference() = default; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_PACKAGE_REFERENCE_H_ + diff --git a/api/request.h b/api/request.h new file mode 100644 index 0000000..ab9b7a7 --- /dev/null +++ b/api/request.h @@ -0,0 +1,136 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_REQUEST_H_ +#define DARWINN_API_REQUEST_H_ + +#include +#include +#include + +#include "api/buffer.h" +#include "port/integral_types.h" +#include "port/status_macros.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// Compute request. Thread-unsafe. +class Request { + public: + // A type for request completion callback. + // The int argument is the same as return value of id(). + using Done = std::function; + + // Fine grain timing information + struct TimingEvent { + // Classify each TPU Request (sub-requests) for logging. + enum class TpuRequestType { + PARAMETER_CACHING, // Request for parameter caching. + INFERENCE // Inference request, single hardware batch. + }; + + // Classify the TimingEvents based on what is happening to the TPU Request. + enum class EventType { + SUBMITTED, // The sub-request was submitted. + COMPLETED // The sub-request was completed. + }; + + int64 timestamp; // When the event occurred. + TpuRequestType request_type; // Request classification for logging. + EventType event_type; // What happened (request creation, completion). + + // In DarwiNN 1.0, requests are sent in order. If that changes in the + // future, need to add a request_id to correlate events belonging to a + // single request, while multiple requests are in flight. + + TimingEvent(int64 timestamp, TpuRequestType type, EventType state) + : timestamp(timestamp), + request_type(type), + event_type(state){} + }; + + // Encapsulates timing information of a request. + struct Timing { + // Timestamp (in nanoseconds) of when the request was first created. + int64 created_ns; + + // Timestamp (in nanoseconds) of when the request was submitted to the + // device for execution. In case of batched requests, this is the time when + // the first batch element is submitted. + int64 submitted_ns; + + // Timestamp (in nanoseconds) of when the request was completed in hardware. + // In case of batched requests, this is the time that the last batch element + // completed execution. + int64 completed_ns; + + // Capture finegrain event timestamps for each single_tpu_request + std::vector detail_timing; + }; + + Request() = default; + virtual ~Request() = default; + + // This class is neither copyable nor movable. + Request(const Request&) = delete; + Request& operator=(const Request&) = delete; + + // Adds an input buffer. This may be called repeatedly depending + // on the batch size as long as the request instance is not submitted. 
The + // size constraints on the input and output buffers will be evaluated during + // Device#Submit. Memory backing the buffer instance must be valid throughout + // the life of the request. + // IMPORTANT: For better performance, please make sure input buffers are + // aligned with at least minimum_alignment_bytes (architecture dependent). If + // possible use Driver::MakeBuffer to get a buffer with this requirement met. + // Buffers with and without padding are both acceptable. + virtual util::Status AddInput(const std::string& name, + const Buffer& input) = 0; + + // Adds an output buffer. This may be called repeatedly depending + // on the batch size as long as the request instance is not submitted. The + // size constraints on the input and output buffers will be evaluated during + // Device#Submit. Memory backing the buffer instance must be valid throughout + // the life of the request. + // + // If the output buffer is user-allocated on-device DRAM, the model must + // ensure that no post-processing will be needed for this output, such as + // re-layout or sign processing. + // TODO -- the API implementation does not currently validate + // that no post-processing will be needed for a user-allocated on-device DRAM + // output. + virtual util::Status AddOutput(const std::string& name, Buffer output) = 0; + + // Sets the scheduling priority of this request (must be a positive int) where + // 0 is highest priority. P0 requests are immediately scheduled for execution + // while lower priorities (higher in value) may get preempted if device is + // busy. By default, a request is P0. + virtual util::Status SetPriority(int priority) = 0; + + // Returns timing information of this request. It can only be called when the + // request is done. + virtual util::StatusOr GetTiming() const = 0; + + // Returns an ID to track the request. + virtual int id() const = 0; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_REQUEST_H_ diff --git a/api/runtime_version.h b/api/runtime_version.h new file mode 100644 index 0000000..1463268 --- /dev/null +++ b/api/runtime_version.h @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_RUNTIME_VERSION_H_ +#define DARWINN_API_RUNTIME_VERSION_H_ + +namespace platforms { +namespace darwinn { +namespace api { + +enum RuntimeVersion { + // A DariwiNN package carrying runtime version less than or equal to this + // value would trigger a warning from driver, during model registration. + // This is the lower bound of kCurrent, and shall not be modified. + kMinValidRuntimeVersion = 10, + + // Increase this number everytime a change involving both compiler and runtime + // happens. This number is used in binary compatibility checks for DarwiNN + // packages. + kCurrent = 13, + // This is the runtime version that has native batch support. + kWithNativeBatchSupport = 11, + // This is the runtime version that has support for int8 as host data type. 
+ kWithInt8HostDataTypeSupport = 12, + // This is the runtime version that has support for 16-bit floating point. + kWithFp16Support = 13, +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_RUNTIME_VERSION_H_ diff --git a/api/telemeter_interface.h b/api/telemeter_interface.h new file mode 100644 index 0000000..07d4f89 --- /dev/null +++ b/api/telemeter_interface.h @@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_TELEMETER_INTERFACE_H_ +#define DARWINN_API_TELEMETER_INTERFACE_H_ + +#include + +#include "api/execution_context_interface.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// This class collects data related to on-device execution on the TPU, and may +// report them back to a server for further analysis. The data could be +// related to performance or execution failure. +// This class is thread-safe. +class TelemeterInterface { + public: + virtual ~TelemeterInterface() = default; + + // Logs the watchdog timeout event. + virtual void LogWatchdogTimeout(const ExecutionContextInterface& context) = 0; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_TELEMETER_INTERFACE_H_ diff --git a/api/tensor_util.cc b/api/tensor_util.cc new file mode 100644 index 0000000..d19ef55 --- /dev/null +++ b/api/tensor_util.cc @@ -0,0 +1,325 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "api/tensor_util.h" + +#include "port/logging.h" +#include "port/string_util.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace api { +namespace tensor_util { + +TensorShapeT MakeTensorShape(const std::vector& dimensions) { + TensorShapeT shape; + shape.dimension.resize(dimensions.size()); + for (int i = 0; i < dimensions.size(); ++i) { + shape.dimension[i] = {0, dimensions[i] - 1}; + } + return shape; +} + +TensorShapeT MakeTensorShape(const std::vector& ranges) { + TensorShapeT shape; + shape.dimension.resize(ranges.size()); + for (int i = 0; i < ranges.size(); ++i) { + shape.dimension[i] = ranges[i]; + } + return shape; +} + +TensorShapeT GetIntersectShape(const TensorShapeT& one, + const TensorShapeT& two) { + CHECK_EQ(one.dimension.size(), two.dimension.size()); + TensorShapeT intersect; + intersect.dimension.resize(one.dimension.size()); + for (int i = 0; i < one.dimension.size(); ++i) { + intersect.dimension[i] = { + std::max(one.dimension[i].start(), two.dimension[i].start()), + std::min(one.dimension[i].end(), two.dimension[i].end())}; + } + return intersect; +} + +bool IsValidShape(const TensorShape& shape) { + if (shape.dimension()->size() == 0) { + return false; + } + + for (int i = 0; i < shape.dimension()->size(); ++i) { + if (shape.dimension()->Get(i)->start() > shape.dimension()->Get(i)->end()) { + return false; + } + } + return true; +} + +bool IsValidShape(const TensorShapeT& shape) { + if (shape.dimension.empty()) { + return false; + } + + for (int i = 0; i < shape.dimension.size(); ++i) { + if (shape.dimension[i].start() > shape.dimension[i].end()) { + return false; + } + } + return true; +} + +int GetNumElementsInShape(const TensorShape& shape) { + int num_elements = 1; + for (int i = 0; i < shape.dimension()->size(); ++i) { + const int length = shape.dimension()->Get(i)->end() - + shape.dimension()->Get(i)->start() + 1; + CHECK_GT(length, 0); + num_elements *= length; + } + return num_elements; +} + +int GetNumElementsInShape(const TensorShapeT& shape) { + int num_elements = 1; + for (int i = 0; i < shape.dimension.size(); ++i) { + const int length = + shape.dimension[i].end() - shape.dimension[i].start() + 1; + CHECK_GT(length, 0); + num_elements *= length; + } + return num_elements; +} + +int GetDimensionLength(const TensorShape& shape, int dimension) { + return shape.dimension()->Get(dimension)->end() - + shape.dimension()->Get(dimension)->start() + 1; +} + +int GetDimensionLength(const TensorShapeT& shape, int dimension) { + return shape.dimension.at(dimension).end() - + shape.dimension.at(dimension).start() + 1; +} + +bool IsElementInShape(const TensorShape& shape, + const std::vector& position) { + CHECK_EQ(position.size(), shape.dimension()->size()); + for (int i = 0; i < shape.dimension()->size(); ++i) { + const auto range = shape.dimension()->Get(i); + if (position[i] < range->start() || position[i] > range->end()) { + return false; + } + } + return true; +} + +bool IsElementInShape(const TensorShapeT& shape, + const std::vector& position) { + CHECK_EQ(position.size(), shape.dimension.size()); + for (int i = 0; i < shape.dimension.size(); ++i) { + const auto& range = shape.dimension[i]; + if (position[i] < range.start() || position[i] > range.end()) { + return false; + } + } + return true; +} + +std::unique_ptr BuildPackedLayout(const TensorShape& shape) { + auto layout = gtl::MakeUnique(); + + // Fill shape information. 
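+  // For example (illustrative): a shape with dimension lengths {2, 3, 4}
+  // yields packed row-major strides {12, 4, 1}, i.e. the element at
+  // position (a, b, c) lands at memory index 12*a + 4*b + 1*c.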
+ layout->shape = gtl::MakeUnique(); + shape.UnPackTo(layout->shape.get()); + + // Fill stride information. + layout->stride.resize(layout->shape->dimension.size()); + int current_stride = 1; + for (int i = layout->shape->dimension.size() - 1; i >= 0; --i) { + layout->stride[i] = current_stride; + current_stride *= GetDimensionLength(*layout->shape, i); + } + + return layout; +} + +std::unique_ptr BuildPackedLayout(const TensorShapeT& shape) { + auto layout = gtl::MakeUnique(); + + // Fill shape information. + layout->shape = gtl::MakeUnique(shape); + + // Fill stride information. + layout->stride.resize(layout->shape->dimension.size()); + int current_stride = 1; + for (int i = layout->shape->dimension.size() - 1; i >= 0; --i) { + layout->stride[i] = current_stride; + current_stride *= GetDimensionLength(*layout->shape, i); + } + + return layout; +} + +bool IsValidLayout(const TensorLayout& layout) { + const auto& shape = *layout.shape(); + if (!IsValidShape(shape)) { + return false; + } + + for (int i = 0; i < shape.dimension()->size() - 1; ++i) { + if (layout.stride()->Get(i) < + layout.stride()->Get(i + 1) * GetDimensionLength(shape, i + 1)) { + return false; + } + } + + return true; +} + +bool IsValidLayout(const TensorLayoutT& layout) { + const auto& shape = *layout.shape; + if (!IsValidShape(shape)) { + return false; + } + + for (int i = 0; i < shape.dimension.size() - 1; ++i) { + if (layout.stride[i] < + layout.stride[i + 1] * GetDimensionLength(shape, i + 1)) { + return false; + } + } + + return true; +} + +bool IsNoPaddingLayout(const TensorLayout& layout) { + CHECK(IsValidLayout(layout)); + const auto& shape = *layout.shape(); + + // There's no padding in layout if the stride equals to the dimension size. + for (int i = 0; i < shape.dimension()->size() - 1; ++i) { + if (layout.stride()->Get(i) != + layout.stride()->Get(i + 1) * GetDimensionLength(shape, i + 1)) { + return false; + } + } + return true; +} + +int GetLayoutSizeInElements(const TensorLayout& layout) { + CHECK(IsValidLayout(layout)); + return GetDimensionLength(*layout.shape(), 0) * layout.stride()->Get(0); +} + +int GetLayoutSizeInElements(const TensorLayoutT& layout) { + CHECK(IsValidLayout(layout)) << DumpLayout(layout); + return GetDimensionLength(*layout.shape, 0) * layout.stride[0]; +} + +int GetMemoryIndexFromPosition(const TensorLayout& layout, + const std::vector& position) { + CHECK(IsElementInShape(*layout.shape(), position)); + int memory_index = 0; + for (int i = 0; i < position.size(); ++i) { + const int min_index = layout.shape()->dimension()->Get(i)->start(); + const int stride = layout.stride()->Get(i); + memory_index += stride * (position[i] - min_index); + } + return memory_index; +} + +int GetMemoryIndexFromPosition(const TensorLayoutT& layout, + const std::vector& position) { + CHECK(IsElementInShape(*layout.shape, position)); + int memory_index = 0; + for (int i = 0; i < position.size(); ++i) { + const int min_index = layout.shape->dimension.at(i).start(); + const int stride = layout.stride.at(i); + memory_index += stride * (position[i] - min_index); + } + return memory_index; +} + +int GetFirstMemoryIndexForShape(const TensorLayout& layout, + const TensorShapeT& shape) { + std::vector position(shape.dimension.size()); + for (int i = 0; i < shape.dimension.size(); ++i) { + position[i] = shape.dimension[i].start(); + } + return GetMemoryIndexFromPosition(layout, position); +} + +int GetLastMemoryIndexForShape(const TensorLayout& layout, + const TensorShapeT& shape) { + std::vector 
position(shape.dimension.size()); + for (int i = 0; i < shape.dimension.size(); ++i) { + position[i] = shape.dimension[i].end(); + } + return GetMemoryIndexFromPosition(layout, position); +} + +bool IsShapeInContiguousLayout(const TensorLayout& layout, + const TensorShapeT& shape) { + const int first_index = GetFirstMemoryIndexForShape(layout, shape); + const int last_index = GetLastMemoryIndexForShape(layout, shape); + return GetNumElementsInShape(shape) == (last_index - first_index + 1); +} + +std::string DumpShape(const TensorShape& shape) { + std::string str; + for (int i = 0; i < shape.dimension()->size(); ++i) { + const auto range = shape.dimension()->Get(i); + StrAppend(&str, StringPrintf("[%d:%d]", range->start(), range->end())); + } + return str; +} + +std::string DumpShape(const TensorShapeT& shape) { + std::string str; + for (int i = 0; i < shape.dimension.size(); ++i) { + StrAppend(&str, StringPrintf("[%d:%d]", shape.dimension[i].start(), + shape.dimension[i].end())); + } + return str; +} + +std::string DumpLayout(const TensorLayout& layout) { + std::string str = + StringPrintf("shape=%s", DumpShape(*layout.shape()).c_str()); + StrAppend(&str, ",stride="); + for (int i = 0; i < layout.stride()->size(); ++i) { + if (i > 0) { + StrAppend(&str, "/"); + } + StrAppend(&str, StringPrintf("%d", layout.stride()->Get(i))); + } + return str; +} + +std::string DumpLayout(const TensorLayoutT& layout) { + std::string str = StringPrintf("shape=%s", DumpShape(*layout.shape).c_str()); + StrAppend(&str, ",stride="); + for (int i = 0; i < layout.stride.size(); ++i) { + if (i > 0) { + StrAppend(&str, "/"); + } + StrAppend(&str, StringPrintf("%d", layout.stride[i])); + } + return str; +} + +} // namespace tensor_util +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/api/tensor_util.h b/api/tensor_util.h new file mode 100644 index 0000000..8f36436 --- /dev/null +++ b/api/tensor_util.h @@ -0,0 +1,116 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_TENSOR_UTIL_H_ +#define DARWINN_API_TENSOR_UTIL_H_ + +#include +#include + +#include "executable/executable_generated.h" +#include "port/logging.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace api { +namespace tensor_util { + +// Enum for Tensor shape dimension index. +enum ShapeDimension { + kBatch = 0, + kY = 1, + kX = 2, + kZ = 3, + kNumDimensions = 4, +}; + +// Creates a tensor shape object for the dimension lengths. +TensorShapeT MakeTensorShape(const std::vector& dimensions); + +// Create a tensor shape object with range information for each dimension. +TensorShapeT MakeTensorShape(const std::vector& ranges); + +// Returns true if all dimensions have valid index ranges. +bool IsValidShape(const TensorShape& shape); +bool IsValidShape(const TensorShapeT& shape); + +// Returns an intersection of two shapes. It will return an invalid shape if +// there is no intersection. 
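+// For example (illustrative): intersecting ranges [0:5] and [3:9] yields
+// [3:5], while intersecting [0:2] with [4:6] yields [4:2], which
+// IsValidShape() rejects because start > end.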
+TensorShapeT GetIntersectShape(const TensorShapeT& one, + const TensorShapeT& two); + +// Returns number of elemens in a tensor shape. +int GetNumElementsInShape(const TensorShape& shape); +int GetNumElementsInShape(const TensorShapeT& shape); + +// Returns the length of a shape dimension. +int GetDimensionLength(const TensorShape& shape, int dimension); +int GetDimensionLength(const TensorShapeT& shape, int dimension); + +// Returns true if a tensor element specified by the position is included in the +// shape. +bool IsElementInShape(const TensorShape& shape, + const std::vector& position); +bool IsElementInShape(const TensorShapeT& shape, + const std::vector& position); + +// Returns a row-major-packed layout for a tensor shape. +std::unique_ptr BuildPackedLayout(const TensorShape& shape); +std::unique_ptr BuildPackedLayout(const TensorShapeT& shape); + +// Returns true if all dimensions have valid index ranges. +bool IsValidLayout(const TensorLayout& layout); +bool IsValidLayout(const TensorLayoutT& layout); + +// Returns true if the layout has no padding. +bool IsNoPaddingLayout(const TensorLayout& layout); + +// Return the memory space size for the layout. This can be different from the +// number of valid elements in the layout due to stride. +int GetLayoutSizeInElements(const TensorLayout& layout); +int GetLayoutSizeInElements(const TensorLayoutT& layout); + +// Returns a linear memory index from a tensor position (a list of indexes). +int GetMemoryIndexFromPosition(const TensorLayout& layout, + const std::vector& position); +int GetMemoryIndexFromPosition(const TensorLayoutT& layout, + const std::vector& position); + +// Returns a linear memory index of a tensor's first element in memory. +int GetFirstMemoryIndexForShape(const TensorLayout& layout, + const TensorShapeT& shape); + +// Returns a linear memory index of a tensor's last element in memory. +int GetLastMemoryIndexForShape(const TensorLayout& layout, + const TensorShapeT& shape); + +// Returns true if all tensor elements are stored in a contiguous layout. +bool IsShapeInContiguousLayout(const TensorLayout& layout, + const TensorShapeT& shape); + +// Dumps shape information. +std::string DumpShape(const TensorShape& shape); +std::string DumpShape(const TensorShapeT& shape); + +// Dumps layout information. +std::string DumpLayout(const TensorLayout& layout); +std::string DumpLayout(const TensorLayoutT& layout); + +} // namespace tensor_util +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_TENSOR_UTIL_H_ diff --git a/api/timing.h b/api/timing.h new file mode 100644 index 0000000..22608d9 --- /dev/null +++ b/api/timing.h @@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_TIMING_H_ +#define DARWINN_API_TIMING_H_ +#include "port/stringprintf.h" +namespace platforms { +namespace darwinn { +namespace api { + +// Timing information for real-time/QoS scheduler when applicable. 
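+//
+// For example (illustrative): at 25 FPS the frame period is 1000 / 25 = 40 ms,
+// so with a max execution time of 30 ms the tolerance must satisfy
+// 0 <= tolerance_ms <= 40 - 30 = 10.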
+struct Timing { + // Inference arrival rate, in FPS. + int fps{0}; + // Max execution time (MET), in milliseconds. + int max_execution_time_ms{0}; + // Tolerance, or how much an inference can be delayed. Also in milliseconds. + // 0 <= Tolerance <= (1/FPS - MET). + int tolerance_ms{0}; + + std::string Dump() const { + return StringPrintf("(%d FPS; max execution time %d ms; tolerance %d ms)", + fps, max_execution_time_ms, tolerance_ms); + } +}; + +// Trivial equality operator for Timing. +inline bool operator==(const Timing& lhs, const Timing& rhs) { + return lhs.fps == rhs.fps && + lhs.max_execution_time_ms == rhs.max_execution_time_ms && + lhs.tolerance_ms == rhs.tolerance_ms; +} + +inline bool operator!=(const Timing& lhs, const Timing& rhs) { + return !(lhs == rhs); +} + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_TIMING_H_ diff --git a/api/watchdog.cc b/api/watchdog.cc new file mode 100644 index 0000000..0f5bf5a --- /dev/null +++ b/api/watchdog.cc @@ -0,0 +1,390 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/watchdog.h" + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) +#include + +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/time.h" +#include "port/timer.h" + +namespace platforms { +namespace darwinn { +namespace api { + +namespace { + +inline int64 GetNextActivationId(int64 current_id) { + return (current_id == INT64_MAX) ? 0 : current_id + 1; +} + +} // namespace + +std::unique_ptr Watchdog::MakeWatchdog(int64 timeout_ns, + Expire expire) { + if (timeout_ns > 0) { + return gtl::MakeUnique(timeout_ns, expire); + } + return gtl::MakeUnique(); +} + +TimerFdWatchdog::TimerFdWatchdog(int64 timeout_ns, Expire expire) + : TimerFdWatchdog(timeout_ns, std::move(expire), + gtl::MakeUnique()) {} + +TimerFdWatchdog::TimerFdWatchdog(int64 timeout_ns, Expire expire, + std::unique_ptr timer) + : expire_(std::move(expire)), + timeout_ns_(timeout_ns), + timer_(std::move(timer)) { + CHECK_GT(timeout_ns_, 0); + watcher_thread_ = std::thread([this]() { Watcher(); }); +} + +TimerFdWatchdog::~TimerFdWatchdog() { + { + StdMutexLock lock(&mutex_); + // 'DESTROYED' indicates that the watcher_thread_ should exit the loop. + // In case the watchdog is still BARKING, we set it to DESTROYED, so the + // watcher_thread_ can gracefully exit after the callback returns. The only + // side effect is that state_ is DESTROYED even though callback is running. + // Should be okay since nobody will query the watchdog state after this. 
+ CHECK(state_ == WatchdogState::INACTIVE || + state_ == WatchdogState::BARKING); + state_ = WatchdogState::DESTROYED; + CHECK_OK(timer_->Set(1)); + } + + watcher_thread_.join(); +} + +util::StatusOr TimerFdWatchdog::Activate() { + StdMutexLock lock(&mutex_); + switch (state_) { + case WatchdogState::ACTIVE: + break; // Already active: Return old activation_id. + case WatchdogState::BARKING: + return util::FailedPreconditionError( + "Cannot activate a barking watchdog."); + case WatchdogState::INACTIVE: + VLOG(5) << "Activating the watchdog."; + RETURN_IF_ERROR(timer_->Set(timeout_ns_)); + state_ = WatchdogState::ACTIVE; + activation_id_ = GetNextActivationId(activation_id_); + break; + case WatchdogState::DESTROYED: + return util::FailedPreconditionError( + "Cannot activate a destroyed watchdog."); + } + return activation_id_; +} + +util::Status TimerFdWatchdog::Signal() { + StdMutexLock lock(&mutex_); + switch (state_) { + case WatchdogState::ACTIVE: + VLOG(5) << "Signalling the watchdog."; + RETURN_IF_ERROR(timer_->Set(timeout_ns_)); + return util::OkStatus(); + case WatchdogState::BARKING: + return util::OkStatus(); + case WatchdogState::INACTIVE: + case WatchdogState::DESTROYED: + return util::FailedPreconditionError( + "Cannot signal an in-active / destroyed watchdog."); + } +} + +util::Status TimerFdWatchdog::Deactivate() { + StdMutexLock lock(&mutex_); + switch (state_) { + case WatchdogState::ACTIVE: + VLOG(5) << "De-activating an active watchdog."; + RETURN_IF_ERROR(timer_->Set(0)); + state_ = WatchdogState::INACTIVE; + return util::OkStatus(); + case WatchdogState::BARKING: + case WatchdogState::INACTIVE: + // Watchdog is either inactive or will become inactive. Nothing to do. + return util::OkStatus(); + case WatchdogState::DESTROYED: + return util::FailedPreconditionError( + "Cannot deactivate a destroyed watchdog."); + } +} + +util::Status TimerFdWatchdog::UpdateTimeout(int64 timeout_ns) { + if (timeout_ns <= 0) { + return util::InvalidArgumentError(StringPrintf( + "Watchdog timeout should be a positive integer. %lld was provided", + static_cast(timeout_ns))); + } + + StdMutexLock lock(&mutex_); + timeout_ns_ = timeout_ns; + + return util::OkStatus(); +} + +void TimerFdWatchdog::Watcher() { + while (true) { + auto expirations_or_error = timer_->Wait(); + CHECK_OK(expirations_or_error.status()); + auto expirations = expirations_or_error.ValueOrDie(); + if (expirations == 0) { + continue; + } + CHECK_EQ(expirations, 1); + + // Local copies of shared state to be used when we don't have the lock. + int64 activation_id = 0; + bool do_expire = false; + + // Acquire lock to query and update shared state. + { + StdMutexLock lock(&mutex_); + + if (state_ == WatchdogState::DESTROYED) { + VLOG(5) << "Callback watcher thread ended."; + return; + } + + if (state_ != WatchdogState::ACTIVE) { + VLOG(1) << "Timer got triggered but watchdog is not active."; + continue; + } + + do_expire = true; + state_ = WatchdogState::BARKING; + activation_id = activation_id_; + } + + if (do_expire) { + // Callback occurs outside locked region since it might take more time. + VLOG(2) << "Calling watchdog expiration callback with ID:" + << activation_id; + expire_(activation_id); + + // Acquire lock again to update shared state after calling expire_. + { + StdMutexLock lock(&mutex_); + // Watchdog might be destroyed while callback was running. + // In that case, retain the 'DESTROYED' state. 
+ if (state_ != WatchdogState::DESTROYED) { + state_ = WatchdogState::INACTIVE; + } + } + } + } +} + +CountingWatch::~CountingWatch() { + StdMutexLock lock(&mutex_); + if (counter_ != 0) { + LOG(WARNING) << StringPrintf( + "Destructing counting watch while counter is %lld", + static_cast(counter_)); + } +} + +util::Status CountingWatch::Increment() { + StdMutexLock lock(&mutex_); + + if (counter_ == LLONG_MAX) { + return util::InternalError("Reached max counter value."); + } + + counter_++; + VLOG(5) << StringPrintf("Incrementing watch counter to %lld.", + static_cast(counter_)); + return watchdog_->Activate().status(); +} + +util::Status CountingWatch::Decrement() { + StdMutexLock lock(&mutex_); + if (counter_ <= 0) { + return util::FailedPreconditionError( + StringPrintf("Cannot decrement when counter is %lld.", + static_cast(counter_))); + } + + counter_--; + VLOG(5) << StringPrintf("Decrementing watch counter to %lld.", + static_cast(counter_)); + + RETURN_IF_ERROR(watchdog_->Signal()); + + if (counter_ == 0) { + RETURN_IF_ERROR(watchdog_->Deactivate()); + } + + return util::OkStatus(); +} + +CascadeWatchdog::CascadeWatchdog(const std::vector& configs) + : CascadeWatchdog(configs, [](int64 timeout_ns, Expire expire) { + return gtl::MakeUnique(timeout_ns, std::move(expire)); + }) {} + +CascadeWatchdog::CascadeWatchdog(const std::vector& configs, + WatchdogMaker make_watchdog) + : configs_(configs) { + CHECK_GT(configs.size(), 0); + watchdogs_.reserve(configs.size()); + + expiration_callback_thread_ = std::thread([this]() { CallbackExecutor(); }); + + // Set callbacks for each watchdog. Note that there are 3 levels of callback. + // 'make_watchdog' has an anonymous method that calls 'WatchdogExpired', which + // in turn does some checks / book-keeping and invokes the actual callback + // that is registered in the 'configs_' vector. + for (int i = 0; i < configs.size(); ++i) { + watchdogs_.push_back(make_watchdog( + configs[i].timeout_ns, + [this, i](int64 activation_id) { WatchdogExpired(activation_id, i); })); + } +} + +CascadeWatchdog::~CascadeWatchdog() { + { + StdMutexLock lock(&mutex_); + is_alive_ = false; + child_expired_.notify_one(); + } + + expiration_callback_thread_.join(); +} + +void CascadeWatchdog::WatchdogExpired(int64 child_activation_id, int child_id) { + StdMutexLock lock(&mutex_); + if (child_activation_id != child_activation_id_ || + child_id != currently_active_) { + // This means this is a delayed callback for an earlier activation, we + // should skip it. + return; + } + + auto expire = configs_[currently_active_].expire; + auto activation_id = activation_id_; + expirations_.push_back([expire, activation_id]() { expire(activation_id); }); + child_expired_.notify_one(); + + if (currently_active_ < watchdogs_.size() - 1) { + ++currently_active_; + child_activation_id_ = + watchdogs_[currently_active_]->Activate().ValueOrDie(); + } else { + currently_active_ = kNoneActive; + } +} + +util::Status CascadeWatchdog::StartFirstWatchdog() { + ASSIGN_OR_RETURN(child_activation_id_, watchdogs_[0]->Activate()); + currently_active_ = 0; + return util::OkStatus(); +} + +util::Status CascadeWatchdog::DeactivateInternal() { + if (currently_active_ == kNoneActive) { + return util::OkStatus(); + } + + // There is a chance that we end up deactivating an already expired watchdog + // which will result in this call returning OK status but still getting the + // callback. However, callback notices that currently_active_ = kNoneActive + // and does not execute the expiration function. 
+ RETURN_IF_ERROR(watchdogs_[currently_active_]->Deactivate()); + currently_active_ = kNoneActive; + + return util::OkStatus(); +} + +util::StatusOr CascadeWatchdog::Activate() { + StdMutexLock lock(&mutex_); + if (currently_active_ != kNoneActive) { + return activation_id_; + } + RETURN_IF_ERROR(StartFirstWatchdog()); + activation_id_ = GetNextActivationId(activation_id_); + return activation_id_; +} + +util::Status CascadeWatchdog::Signal() { + // Early exit if watchdog is not active + StdMutexLock lock(&mutex_); + if (currently_active_ == kNoneActive) { + VLOG(2) << "Signalled inactive CascadeWatchdog. Ignoring."; + return util::OkStatus(); + } + + RETURN_IF_ERROR(DeactivateInternal()); + return StartFirstWatchdog(); +} + +util::Status CascadeWatchdog::Deactivate() { + StdMutexLock lock(&mutex_); + return DeactivateInternal(); +} + +util::Status CascadeWatchdog::UpdateTimeout(int64 timeout_ns) { + return watchdogs_[0]->UpdateTimeout(timeout_ns); +} + +util::Status CascadeWatchdog::UpdateTimeout(int child_index, int64 timeout_ns) { + if (child_index >= watchdogs_.size()) { + return util::InvalidArgumentError(StringPrintf( + "Invalid child_index %d. We only have %zu child watchdogs.", + child_index, watchdogs_.size())); + } + return watchdogs_[child_index]->UpdateTimeout(timeout_ns); +} + +void CascadeWatchdog::CallbackExecutor() { + while (true) { + std::vector> expirations; + + { + StdCondMutexLock lock(&mutex_); + while (expirations_.empty() && is_alive_) { + child_expired_.wait(lock); + } + + if (!is_alive_) { + return; + } + + expirations = std::move(expirations_); + expirations_.clear(); + } + + for (const auto& expiration : expirations) { + expiration(); + } + } +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/api/watchdog.h b/api/watchdog.h new file mode 100644 index 0000000..695f58f --- /dev/null +++ b/api/watchdog.h @@ -0,0 +1,314 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_API_WATCHDOG_H_ +#define DARWINN_API_WATCHDOG_H_ + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" +#include "port/time.h" +#include "port/timer.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// Watchdog is a class responsible for keeping track of TPU status and sending +// notifications when it is unresponsive. +class Watchdog { + public: + // A callback function to be called when the watch timeout is reached. + using Expire = std::function; + + Watchdog() = default; + virtual ~Watchdog() = default; + + // This class is movable. + Watchdog(Watchdog&& rhs); + Watchdog& operator=(Watchdog&& rhs); + + // This class is not copyable. 
+ Watchdog(const Watchdog&) = delete; + Watchdog& operator=(const Watchdog&) = delete; + + // Decides which watchdog concrete implementation to create based on the + // provided parameters, creates and returns it. + static std::unique_ptr MakeWatchdog(int64 timeout_ns, + Expire expire); + + // Starts the watch. It returns an activation id that can be later on + // used to verify which activation an expiration callback belongs to. + virtual util::StatusOr Activate() = 0; + + // Signals the watchdog that we are still active and healthy. + virtual util::Status Signal() = 0; + + // Ends the watch. + virtual util::Status Deactivate() = 0; + + // Updates watchdog timeout to the provided value in nanoseconds. By + // definition, the new timeout will be effective from the next activation / + // signal. + virtual util::Status UpdateTimeout(int64 timeout_ns) = 0; +}; + +// A No-Op watchdog used for when we don't need a watch (e.g. in tests, +// simulator, etc.). +class NoopWatchdog : public Watchdog { + public: + NoopWatchdog() = default; + ~NoopWatchdog() override = default; + + // This class is movable. + NoopWatchdog(NoopWatchdog&& rhs); + NoopWatchdog& operator=(NoopWatchdog&& rhs); + + // This class is not copyable. + NoopWatchdog(const NoopWatchdog&) = delete; + NoopWatchdog& operator=(const NoopWatchdog&) = delete; + + util::StatusOr Activate() override { return 0; } + util::Status Signal() override { return util::OkStatus(); } + util::Status Deactivate() override { return util::OkStatus(); } + util::Status UpdateTimeout(int64 timeout_ns) override { + return util::OkStatus(); + } +}; + +// A watchdog implementation that uses timerfd (or similar timers) underneath. +class TimerFdWatchdog : public Watchdog { + public: + // This constructor uses timerfd system call. + TimerFdWatchdog(int64 timeout_ns, Expire expire); + + // Accepts any timer interface. In most cases, it is recommended to use the + // first constructor. + TimerFdWatchdog(int64 timeout_ns, Expire expire, + std::unique_ptr timer); + + ~TimerFdWatchdog() override; + + // This class is movable. + TimerFdWatchdog(TimerFdWatchdog&& rhs); + TimerFdWatchdog& operator=(TimerFdWatchdog&& rhs); + + // This class is not copyable. + TimerFdWatchdog(const TimerFdWatchdog&) = delete; + TimerFdWatchdog& operator=(const TimerFdWatchdog&) = delete; + + enum class WatchdogState { + // State Transitions: + // |```````````````````V + // INACTIVE*-->ACTIVE-->BARKING-->INACTIVE-->DESTROYED + // ^--------------------^ + INACTIVE, // Not yet activated or has finished barking. + ACTIVE, // Activated, but not yet barked - signal now to prevent barking. + BARKING, // Activated, and timer expired - callback is being executed. + DESTROYED // Watchdog Destructor has been called - exit watcher thread. + }; + + const char* GetStateString() const EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + switch (state_) { + case WatchdogState::ACTIVE: + return "ACTIVE"; + case WatchdogState::INACTIVE: + return "INACTIVE"; + case WatchdogState::BARKING: + return "BARKING"; + case WatchdogState::DESTROYED: + return "DESTROYED"; + } + } + + util::StatusOr Activate() override LOCKS_EXCLUDED(mutex_); + util::Status Signal() override LOCKS_EXCLUDED(mutex_); + util::Status Deactivate() override LOCKS_EXCLUDED(mutex_); + util::Status UpdateTimeout(int64 timeout_ns) override LOCKS_EXCLUDED(mutex_); + + private: + // This function runs the watch thread that periodically checks the last time + // we heard anything. + void Watcher(); + + // Validates that the watchdog is currently active. 
+ util::Status ValidateActive() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Callback function for when we time out. + const Expire expire_; + + // The amount of time watchdog has to be active and not get a signal in order + // for it to expire. + int64 timeout_ns_; + + // The timer to be used for keeping track of expiration deadlines. + std::unique_ptr timer_; + + // A single mutex to protect mutable fields in this class. + std::mutex mutex_; + + // Watchdog state machine: + WatchdogState state_ GUARDED_BY(mutex_){WatchdogState::INACTIVE}; + + // An id to verify the origin of an expiration callback. + int64 activation_id_ GUARDED_BY(mutex_){0}; + + // The watcher thread that runs Watcher() method. + std::thread watcher_thread_; +}; + +// A wrapper around Watchdog that keeps track of device/code-state health by +// keeping track of the number of things in a pipeline. +class CountingWatch { + public: + // Constructor expects a configured watchdog. Its expiration callback is + // called if Decrement is not called within the timeout and counter is not 0. + explicit CountingWatch(std::unique_ptr watchdog) + : watchdog_(std::move(watchdog)) {} + + ~CountingWatch(); + + // Increments the number of elements in the pipeline by 1. This will result in + // activating the watchdog. + util::Status Increment() LOCKS_EXCLUDED(mutex_); + + // Decrements the number of elements in the pipeline. It fails if counter has + // already reached 0. + util::Status Decrement() LOCKS_EXCLUDED(mutex_); + + private: + // The watchdog we are wrapping. + std::unique_ptr watchdog_; + + // A mutex to ensure thread-safety for accessing members. + std::mutex mutex_; + + // A counter to keep track of number of elements in the pipeline. + int64 counter_ GUARDED_BY(mutex_) = 0; +}; + +// CascadeWatchdog is a multi-level watchdog that has an expiration callback and +// timeout for each level. After activation, if first level timeout expires, +// its callback function gets called and the second watch gets activated +// immediately after. Signaling or de-activating this watchdog resets everything +// back to first level. +class CascadeWatchdog : public Watchdog { + public: + // Encapsulates the configuration needed for each level in the cascade. + struct Config { + // Expiration function for when this watch level expires. + Expire expire; + + // Timeout for triggering the watch (relative to the previous level). + int64 timeout_ns; + }; + + // Creates a CascadeWatchdog provided a vector of configs. The configs are + // used in the provided order meaning the first callback to get triggered is + // the one in configs[0]. There has to be at least one config. + explicit CascadeWatchdog(const std::vector& configs); + + ~CascadeWatchdog() override; + + // This class is movable. + CascadeWatchdog(CascadeWatchdog&& rhs); + CascadeWatchdog& operator=(CascadeWatchdog&& rhs); + + // This class is not copyable. + CascadeWatchdog(const CascadeWatchdog&) = delete; + CascadeWatchdog& operator=(const CascadeWatchdog&) = delete; + + util::StatusOr Activate() override LOCKS_EXCLUDED(mutex_); + util::Status Signal() override LOCKS_EXCLUDED(mutex_); + util::Status Deactivate() override LOCKS_EXCLUDED(mutex_); + + // Updates the timeout of the first child watchdog (the first one that expires + // ). Use the overloaded method for updating timeouts of other child + // watchdogs. + util::Status UpdateTimeout(int64 timeout_ns) override; + + // Updates the timeout of the child watchdog at the provided index. 
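+ // Child timeouts are relative: child i's timer only starts once child i-1
+ // has expired, so updating child_index 1 changes the delay between the
+ // first and second expiration callbacks, not an absolute deadline.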
+ util::Status UpdateTimeout(int child_index, int64 timeout_ns); + + protected: + // A method that can create and return a child watchdog to be used here. + using WatchdogMaker = std::function(int64, Expire)>; + + // A constructor that accepts a WatchdogMaker to use for creating the child + // watchdogs. + CascadeWatchdog(const std::vector& configs, + WatchdogMaker make_watchdog); + + // A vector of the underlying child watchdogs in the same order as configs. + std::vector> watchdogs_; + + private: + // The method that gets called in any of the child watchdogs expire. + void WatchdogExpired(int64 child_activation_id, int child_id) + LOCKS_EXCLUDED(mutex_); + + // The function responsible for executing expiration callbacks. + void CallbackExecutor(); + + // Start the first watchdog. Called by Activate and Signal. + util::Status StartFirstWatchdog() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Implement actual Deactivate method here to simplify some mutex locking. + util::Status DeactivateInternal() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // The list of watchdog configs as provided by the object owner. + const std::vector configs_; + + // A mutex to protect mutable class fields. + std::mutex mutex_; + + // Specifies which child watchdog is currently active. It can be used as an + // index to configs_ or watchdogs_. -1 means all watchdogs are inactive. + static constexpr int kNoneActive = -1; + int currently_active_ GUARDED_BY(mutex_){kNoneActive}; + + // The current/last generated activation ID for the caller of Activate on this + // class. + int64 activation_id_ GUARDED_BY(mutex_){0}; + + // At any given point in time at most 1 child watchdog is active. This field + // specifies the activation ID of that watchdog. + int64 child_activation_id_ GUARDED_BY(mutex_){0}; + + // A list of expiration callbacks that need to be executed. + std::vector> expirations_ GUARDED_BY(mutex_); + + // To notify that at least one child watchdog has expired. + std::condition_variable child_expired_; + + // The thread that executes expiration callbacks. + std::thread expiration_callback_thread_; + + // Specifies if watchdog is alive. This is used in the destructor to signal + // other threads that it is time to quit. + bool is_alive_ GUARDED_BY(mutex_){true}; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_API_WATCHDOG_H_ diff --git a/bazel/WORKSPACE b/bazel/WORKSPACE new file mode 100644 index 0000000..f9cb049 --- /dev/null +++ b/bazel/WORKSPACE @@ -0,0 +1,63 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
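Before the Bazel workspace definition below, here is a minimal usage sketch tying together the watchdog interfaces declared in api/watchdog.h above. The function name, timeout values, and log messages are illustrative and the includes are approximate; the calls themselves (Watchdog::MakeWatchdog, Activate, Signal, Deactivate, and CascadeWatchdog::Config) follow the declared API.

```cpp
#include "api/watchdog.h"
#include "port/logging.h"

namespace platforms {
namespace darwinn {

// Illustrative only: names, timeouts, and log messages are made up.
void WatchdogUsageSketch() {
  // Single-level watchdog: |expire| fires if 500 ms pass without a Signal().
  auto dog = api::Watchdog::MakeWatchdog(
      /*timeout_ns=*/500LL * 1000 * 1000,
      /*expire=*/[](int64 activation_id) {
        LOG(ERROR) << "TPU unresponsive, activation " << activation_id;
      });

  int64 activation_id = dog->Activate().ValueOrDie();
  // ... submit work to the device, wait for completion ...
  CHECK_OK(dog->Signal());      // Still healthy; restarts the timer.
  CHECK_OK(dog->Deactivate());  // Pipeline drained; stop watching.
  (void)activation_id;

  // Two-level escalation: warn after 1 s of silence, escalate 2 s after that.
  api::CascadeWatchdog cascade({
      {[](int64 id) { LOG(WARNING) << "Slow response, activation " << id; },
       1000LL * 1000 * 1000},
      {[](int64 id) { LOG(ERROR) << "Still stuck, activation " << id; },
       2000LL * 1000 * 1000},
  });
  CHECK_OK(cascade.Activate().status());
  // ... Signal() periodically; if the signals stop, the callbacks fire in
  // order, and signalling again resets the cascade to its first level ...
  CHECK_OK(cascade.Deactivate());
}

}  // namespace darwinn
}  // namespace platforms
```

CountingWatch builds on the same calls: Increment() activates the underlying watchdog, Decrement() signals it, and the watchdog is deactivated once the counter drains back to zero.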
+workspace(name = "libedgetpu") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "io_bazel_rules_closure", + sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", + strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", + urls = [ + "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 + ], +) + +# Be consistent with tensorflow/WORKSPACE. +http_archive( + name = "bazel_skylib", + sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz"], +) # https://github.com/bazelbuild/bazel-skylib/releases + +# The TF commit # here must be in sync with that specified under Gob edgetpu +# repo WORKSPACE file. +# TODO: figure out a way to keep single source of truth of the +# TF commit # used. +TENSORFLOW_COMMIT = "f394a768719a55b5c351ed1ecab2ec6f16f99dd4"; +# Command to calculate: curl -OL | sha256sum | awk '{print $1}' +TENSORFLOW_SHA256 = "cb286abee7ee9cf5c8701d85fcc88f0fd59e72492ec4f254156de486e3e905c1" +http_archive( + name = "org_tensorflow", + sha256 = TENSORFLOW_SHA256, + strip_prefix = "tensorflow-" + TENSORFLOW_COMMIT, + urls = [ + "https://github.com/tensorflow/tensorflow/archive/" + TENSORFLOW_COMMIT + ".tar.gz", + ], +) + +load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") +tf_workspace(tf_repo_name = "org_tensorflow") + +http_archive( + name = "coral_crosstool", + sha256 = "cb31b1417ccdcf7dd9fca5ec63e1571672372c30427730255997a547569d2feb", + strip_prefix = "crosstool-9e00d5be43bf001f883b5700f5d04882fea00229", + urls = [ + "https://github.com/google-coral/crosstool/archive/9e00d5be43bf001f883b5700f5d04882fea00229.tar.gz", + ], +) +load("@coral_crosstool//:configure.bzl", "cc_crosstool") +cc_crosstool(name = "crosstool") diff --git a/bazel/WORKSPACE.darwin b/bazel/WORKSPACE.darwin new file mode 100644 index 0000000..ad08cb5 --- /dev/null +++ b/bazel/WORKSPACE.darwin @@ -0,0 +1,13 @@ +# Use libusb from MacPorts. +new_local_repository( + name = "libusb", + path = "/opt/local/include/", + build_file_content = """ +cc_library( + name = "headers", + includes = ["."], + hdrs = ["libusb-1.0/libusb.h"], + visibility = ["//visibility:public"], +) +""" +) diff --git a/bazel/WORKSPACE.linux b/bazel/WORKSPACE.linux new file mode 100644 index 0000000..c33e185 --- /dev/null +++ b/bazel/WORKSPACE.linux @@ -0,0 +1,12 @@ +new_local_repository( + name = "libusb", + path = "/usr/include/", + build_file_content = """ +cc_library( + name = "headers", + includes = ["."], + hdrs = ["libusb-1.0/libusb.h"], + visibility = ["//visibility:public"], +) +""" +) diff --git a/bazel/WORKSPACE.windows b/bazel/WORKSPACE.windows new file mode 100644 index 0000000..9994174 --- /dev/null +++ b/bazel/WORKSPACE.windows @@ -0,0 +1,21 @@ +# For Windows, extract libusb archive to ../libusb-1.0.22 +# Unfortunately, can't use http_archive here, as bazel doesn't support 7z. 
+# https://github.com/libusb/libusb/releases/download/v1.0.22/libusb-1.0.22.7z +new_local_repository( + name = "libusb", + path = "../libusb-1.0.22", + build_file_content = """ +cc_library( + name = "headers", + includes = ["include"], + hdrs = ["include/libusb-1.0/libusb.h"], + visibility = ["//visibility:public"], +) +cc_import( + name = "shared", + interface_library = "MS64/dll/libusb-1.0.lib", + shared_library = "MS64/dll/libusb-1.0.dll", + visibility = ["//visibility:public"], +) +""" +) \ No newline at end of file diff --git a/build.bat b/build.bat new file mode 100644 index 0000000..12d0acc --- /dev/null +++ b/build.bat @@ -0,0 +1,74 @@ +:: Copyright 2019 Google LLC +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: https://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +echo off + +setlocal + +set THROTTLED=0 +set COMPILATION_MODE=opt +set OUT_DIR=%~dp0\out +set CPU=x64_windows + +:PROCESSARGS +set ARG=%1 +if defined ARG ( + if "%ARG%"=="/DBG" ( + set COMPILATION_MODE=dbg + ) else ( + set LEFTOVER_ARGS=%LEFTOVER_ARGS% %ARG% + ) + shift + goto PROCESSARGS +) + +set BAZEL_INFO_FLAGS=^ +--experimental_repo_remote_exec +for /f %%i in ('bazel info %BAZEL_INFO_FLAGS% output_path') do set "BAZEL_OUTPUT_PATH=%%i" +set BAZEL_OUTPUT_PATH=%BAZEL_OUTPUT_PATH:/=\% +set BAZEL_OUT_DIR=%BAZEL_OUTPUT_PATH%\%CPU%-%COMPILATION_MODE%\bin + + +set TARGET=//tflite/public:edgetpu_direct_usb.dll +set BAZEL_VS=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools +set BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC +set BAZEL_BUILD_FLAGS= ^ +--experimental_repo_remote_exec ^ +--compilation_mode %COMPILATION_MODE% ^ +--define darwinn_portable=1 ^ +--copt=/DSTRIP_LOG=1 ^ +--copt=/DABSL_FLAGS_STRIP_NAMES ^ +--copt=/DEDGETPU_EXTERNAL_RELEASE_RUNTIME ^ +--copt=/GR- ^ +--copt=/DWIN32_LEAN_AND_MEAN ^ +--copt=/D_WINSOCKAPI_ ^ +--copt=/std:c++latest + +call "%BAZEL_VC%\Auxiliary\Build\vcvars64.bat" + +bazel build %BAZEL_BUILD_FLAGS% %LEFTOVER_ARGS% %TARGET% +md %OUT_DIR%\direct\%CPU% +copy %BAZEL_OUT_DIR%\tflite\public\edgetpu_direct_usb.dll ^ + %OUT_DIR%\direct\%CPU%\ +python %~dp0\rename_library.py ^ + --input_dll %OUT_DIR%\direct\%CPU%\edgetpu_direct_usb.dll ^ + --output_dll %OUT_DIR%\direct\%CPU%\edgetpu.dll + +set BAZEL_BUILD_FLAGS=%BAZEL_BUILD_FLAGS% --copt=/DTHROTTLE_EDGE_TPU +bazel build %BAZEL_BUILD_FLAGS% %LEFTOVER_ARGS% %TARGET% +md %OUT_DIR%\throttled\%CPU% +copy %BAZEL_OUT_DIR%\tflite\public\edgetpu_direct_usb.dll ^ + %OUT_DIR%\throttled\%CPU%\ +python %~dp0\rename_library.py ^ + --input_dll %OUT_DIR%\throttled\%CPU%\edgetpu_direct_usb.dll ^ + --output_dll %OUT_DIR%\throttled\%CPU%\edgetpu.dll \ No newline at end of file diff --git a/build_defs.bzl b/build_defs.bzl new file mode 100644 index 0000000..5d465fb --- /dev/null +++ b/build_defs.bzl @@ -0,0 +1,13 @@ +"""Utilities for darwinn.""" + +def darwinn_port_defines(): + """Generates a list of port defines suitable for the build. + + Returns: + List of defines. 
+ """ + return select({ + "//:darwinn_portable": ["DARWINN_PORT_DEFAULT"], + "//:darwinn_firmware": ["DARWINN_PORT_FIRMWARE"], + "//conditions:default": ["DARWINN_PORT_GOOGLE3"], + }) diff --git a/driver/BUILD b/driver/BUILD new file mode 100644 index 0000000..49d6e9d --- /dev/null +++ b/driver/BUILD @@ -0,0 +1,454 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Portable DarwiNN driver implementation. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["libdarwinn_driver.lds"]) + +# Utility functions. +cc_library( + name = "util", + hdrs = [ + "bitfield.h", + ], + deps = [ + "//port:integral_types", + "//port:logging", + ], +) + +cc_library( + name = "allocator", + srcs = [ + "aligned_allocator.cc", + "allocator.cc", + ], + hdrs = [ + "aligned_allocator.h", + "allocator.h", + ], + deps = [ + "//api:allocated_buffer", + "//api:buffer", + "//port", + ], +) + +cc_library( + name = "device_buffer", + srcs = ["device_buffer.cc"], + hdrs = ["device_buffer.h"], + deps = [ + "//port", + ], +) + +cc_library( + name = "executable_util", + srcs = ["executable_util.cc"], + hdrs = ["executable_util.h"], + deps = [ + "//api:buffer", + "//executable:executable_fbs", + "//port", + ], +) + +# Driver Factory. 
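+# The factory picks its platform-specific source through the select() below:
+# Windows and macOS get dedicated implementations, while every other platform
+# falls back to driver_factory_default.cc.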
+cc_library( + name = "driver_factory", + srcs = ["driver_factory.cc"] + select({ + "//:windows": ["driver_factory_windows.cc"], + "//:darwin": ["driver_factory_darwin.cc"], + "//conditions:default": ["driver_factory_default.cc"], + }), + hdrs = [ + "driver_factory.h", + ], + deps = [ + ":driver", + "//api:chip", + "//api:driver", + "//api:driver_factory", + "//api:driver_options_fbs", + "//api:driver_options_helper", + "//driver/config", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "driver_helper", + srcs = ["driver_helper.cc"], + hdrs = ["driver_helper.h"], + deps = [ + ":executable_util", + ":package_registry", + ":test_vector", + "//api:buffer", + "//api:chip", + "//api:driver", + "//api:package_reference", + "//api:request", + "//api:telemeter_interface", + "//api:timing", + "//executable:executable_fbs", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "instruction_buffers", + srcs = ["instruction_buffers.cc"], + hdrs = ["instruction_buffers.h"], + deps = [ + ":allocator", + ":device_buffer_mapper", + ":executable_util", + "//api:buffer", + "//executable:executable_fbs", + "//port", + "//port:tracing", + ], +) + +cc_library( + name = "run_controller", + srcs = ["run_controller.cc"], + hdrs = ["run_controller.h"], + deps = [ + ":hardware_structures", + "//driver/config", + "//driver/config:register_constants", + "//driver/registers", + "//port", + ], +) + +cc_library( + name = "scalar_core_controller", + srcs = ["scalar_core_controller.cc"], + hdrs = ["scalar_core_controller.h"], + deps = [ + "//driver/config", + "//driver/interrupt:interrupt_controller", + "//driver/registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "hardware_structures", + hdrs = ["hardware_structures.h"], + deps = [ + "//port:integral_types", + "//port:macros", + ], +) + +cc_library( + name = "driver", + srcs = ["driver.cc"], + hdrs = ["driver.h"], + deps = [ + ":default_telemeter", + ":device_buffer_mapper", + ":package_registry", + ":request", + ":tpu_request", + "@com_google_absl//absl/strings:str_format", + "//api:buffer", + "//api:chip", + "//api:driver", + "//api:execution_context_interface", + "//api:package_reference", + "//api:request", + "//api:telemeter_interface", + "//driver/memory:dma_direction", + "//driver/time_stamper", + "//executable:executable_fbs", + "//port", + "//port:blocking_counter", + "//port:shared_mutex", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "default_telemeter", + hdrs = ["default_telemeter.h"], + deps = [ + "//api:telemeter_interface", + ], +) + +cc_library( + name = "mmio_driver", + srcs = ["mmio_driver.cc"], + hdrs = ["mmio_driver.h"], + deps = [ + ":allocator", + ":device_buffer", + ":device_buffer_mapper", + ":dma_info_extractor", + ":driver", + ":hardware_structures", + ":package_registry", + ":real_time_dma_scheduler", + ":run_controller", + ":scalar_core_controller", + ":single_tpu_request", + ":top_level_handler", + ":tpu_request", + "//api:allocated_buffer", + "//api:buffer", + "//api:watchdog", + "//driver/config", + "//driver/config:register_constants", + "//driver/interrupt:interrupt_controller_interface", + "//driver/interrupt:interrupt_handler", + "//driver/interrupt:top_level_interrupt_manager", + "//driver/memory:address_space", + "//driver/memory:address_utilities", + "//driver/memory:dma_direction", + 
"//driver/memory:dram_allocator", + "//driver/memory:mmu_mapper", + "//driver/mmio:host_queue", + "//driver/registers", + "//driver/time_stamper", + "//driver/time_stamper:driver_time_stamper", + "//executable:executable_fbs", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "dma_info", + srcs = ["dma_info.cc"], + hdrs = ["dma_info.h"], + deps = [ + ":device_buffer", + "//port", + ], +) + +cc_library( + name = "device_buffer_mapper", + srcs = ["device_buffer_mapper.cc"], + hdrs = ["device_buffer_mapper.h"], + deps = [ + ":device_buffer", + ":hardware_structures", + "//api:buffer", + "//driver/memory:address_space", + "//driver/memory:address_utilities", + "//driver/memory:dma_direction", + "//port", + "//port:tracing", + ], +) + +cc_library( + name = "dma_info_extractor", + srcs = ["dma_info_extractor.cc"], + hdrs = ["dma_info_extractor.h"], + deps = [ + ":device_buffer_mapper", + ":dma_info", + ":package_registry", + "//driver/memory:address_utilities", + "//executable:executable_fbs", + "//port", + ], +) + +cc_library( + name = "dma_scheduler", + hdrs = ["dma_scheduler.h"], + deps = [ + ":dma_info", + ":tpu_request", + "//api:driver", + "//port", + ], +) + +cc_library( + name = "single_queue_dma_scheduler", + srcs = ["single_queue_dma_scheduler.cc"], + hdrs = ["single_queue_dma_scheduler.h"], + deps = [ + ":dma_info", + ":dma_scheduler", + ":tpu_request", + "//api:driver", + "//api:watchdog", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "real_time_dma_scheduler", + srcs = ["real_time_dma_scheduler.cc"], + hdrs = ["real_time_dma_scheduler.h"], + deps = [ + ":dma_info", + ":dma_scheduler", + ":package_registry", + ":single_queue_dma_scheduler", + ":tpu_request", + "@com_google_absl//absl/strings:str_format", + "//api:driver", + "//api:package_reference", + "//api:timing", + "//api:watchdog", + "//driver/time_stamper", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "dma_chunker", + srcs = ["dma_chunker.cc"], + hdrs = ["dma_chunker.h"], + deps = [ + ":device_buffer", + "//port", + ], +) + +cc_library( + name = "top_level_handler", + hdrs = ["top_level_handler.h"], + deps = ["//port"], +) + +cc_library( + name = "package_registry", + srcs = ["package_registry.cc"], + hdrs = ["package_registry.h"], + deps = [ + ":allocator", + ":device_buffer_mapper", + ":instruction_buffers", + ":package_verifier", + "//api:buffer", + "//api:chip", + "//api:driver_options_fbs", + "//api:execution_context_interface", + "//api:layer_information", + "//api:package_reference", + "//api:runtime_version", + "//driver/memory:dram_allocator", + "//executable:executable_fbs", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "package_verifier", + srcs = ["package_verifier.cc"], + hdrs = ["package_verifier.h"], + deps = [ + "//executable:executable_fbs", + "//port", + ], +) + +cc_library( + name = "tpu_request", + hdrs = ["tpu_request.h"], + deps = [ + ":dma_info", + ":package_registry", + "//api:buffer", + "//api:request", + "//port", + ], +) + +cc_library( + name = "single_tpu_request", + srcs = ["single_tpu_request.cc"], + hdrs = ["single_tpu_request.h"], + deps = [ + ":allocator", + ":device_buffer", + ":device_buffer_mapper", + ":dma_info", + ":dma_info_extractor", + ":executable_util", + ":hardware_structures", + ":instruction_buffers", + 
":package_registry", + ":request", + ":tpu_request", + "//api:allocated_buffer", + "//api:buffer", + "//driver/memory:address_space", + "//driver/memory:dram_allocator", + "//executable:executable_fbs", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "request", + srcs = ["request.cc"], + hdrs = ["request.h"], + deps = [ + ":tpu_request", + "//api:request", + "//driver/time_stamper", + "//driver/time_stamper:driver_time_stamper_factory", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +filegroup( + name = "linker_script", + srcs = ["libdarwinn_driver.lds"], +) diff --git a/driver/aligned_allocator.cc b/driver/aligned_allocator.cc new file mode 100644 index 0000000..0896b06 --- /dev/null +++ b/driver/aligned_allocator.cc @@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/aligned_allocator.h" + +#include +#include + +#include "port/aligned_malloc.h" +#include "port/integral_types.h" +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +AlignedAllocator::AlignedAllocator(uint64 alignment_bytes) + : alignment_bytes_(alignment_bytes) { + // Check for power of 2, since we use arithmetic that relies on it elsewhere. + CHECK_EQ((alignment_bytes - 1) & alignment_bytes, 0); +} + +void* AlignedAllocator::Allocate(size_t size) { + int aligned_size = (size + alignment_bytes_ - 1) & ~(alignment_bytes_ - 1); + return aligned_malloc(aligned_size, alignment_bytes_); +} + +void AlignedAllocator::Free(void* aligned_memory) { + aligned_free(aligned_memory); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/aligned_allocator.h b/driver/aligned_allocator.h new file mode 100644 index 0000000..89aede7 --- /dev/null +++ b/driver/aligned_allocator.h @@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_ALIGNED_ALLOCATOR_H_ +#define DARWINN_DRIVER_ALIGNED_ALLOCATOR_H_ + +#include + +#include "driver/allocator.h" +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Convenience class to allocate aligned buffers. +class AlignedAllocator : public Allocator { + public: + // All allocated buffers will be aligned to |alignment_bytes| with a size + // granulairy of |alignment_bytes|. 
+ explicit AlignedAllocator(uint64 alignment_bytes); + ~AlignedAllocator() = default; + + // This class is neither copyable nor movable. + AlignedAllocator(const AlignedAllocator&) = delete; + AlignedAllocator& operator=(const AlignedAllocator&) = delete; + + void* Allocate(size_t size) override; + void Free(void* aligned_memory) override; + + private: + // Alignment + const uint64 alignment_bytes_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_ALIGNED_ALLOCATOR_H_ diff --git a/driver/allocator.cc b/driver/allocator.cc new file mode 100644 index 0000000..6a16e8f --- /dev/null +++ b/driver/allocator.cc @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/allocator.h" + +#include +#include + +#include "api/allocated_buffer.h" +#include "api/buffer.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +Buffer Allocator::MakeBuffer(size_t size_bytes) { + auto free_cb = [this](void* ptr) { Free(ptr); }; + + uint8* ptr = static_cast(Allocate(size_bytes)); + return Buffer( + std::make_shared(ptr, size_bytes, std::move(free_cb))); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/allocator.h b/driver/allocator.h new file mode 100644 index 0000000..ad300ca --- /dev/null +++ b/driver/allocator.h @@ -0,0 +1,47 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_ALLOCATOR_H_ +#define DARWINN_DRIVER_ALLOCATOR_H_ + +#include + +#include "api/buffer.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interface for a class that can allocate host memory. +class Allocator { + public: + virtual ~Allocator() = default; + + // Allocates buffer of specified size. + virtual void* Allocate(size_t size) = 0; + + // Frees a previous allocated buffer. + virtual void Free(void* buffer) = 0; + + // Allocates and returns a buffer of the specified size. The lifecycle of the + // the returned buffer is tied to the Allocator instance. It is thus important + // to ensure that the allocator class outlives the returned buffer instances. 
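+ // The returned Buffer frees its memory through a callback that points back
+ // into this Allocator, which is why the Allocator must outlive it. Note that
+ // the underlying allocation may be larger than requested; with the
+ // AlignedAllocator above and a 64-byte alignment, MakeBuffer(100) is backed
+ // by a 128-byte block, since (100 + 63) & ~63 == 128.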
+ Buffer MakeBuffer(size_t size_bytes); +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_ALLOCATOR_H_ diff --git a/driver/beagle/BUILD b/driver/beagle/BUILD new file mode 100644 index 0000000..b38a989 --- /dev/null +++ b/driver/beagle/BUILD @@ -0,0 +1,259 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Beagle Specific functionality. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +# Provides Beagle USB Driver for internal (=no_external_release) usage. +# libUSB is statically linked in this version. +cc_library( + name = "beagle_usb_driver_provider_no_external_release", + srcs = ["beagle_usb_driver_provider.cc"], + deps = [ + ":beagle_top_level_handler", + ":beagle_top_level_interrupt_manager", + "@com_google_absl//absl/strings", + "//api:chip", + "//api:driver", + "//api:driver_options_fbs", + "//driver:allocator", + "//driver:driver_factory", + "//driver:package_registry", + "//driver:package_verifier", + "//driver:run_controller", + "//driver:scalar_core_controller", + "//driver/config/beagle:beagle_chip_config", + "//driver/interrupt:grouped_interrupt_controller", + "//driver/interrupt:interrupt_controller", + "//driver/interrupt:interrupt_controller_interface", + "//driver/memory:null_dram_allocator", + "//driver/time_stamper:driver_time_stamper", + "//driver/usb:local_usb_device_no_external_release", + "//driver/usb:usb_device_interface", + "//driver/usb:usb_driver", + "//driver/usb:usb_ml_commands", + "//driver/usb:usb_registers", + "//port", + "//port:tracing", + ], + alwayslink = 1, +) + +# Provides Beagle USB Driver. +# libUSB is dynamically linked in this version. +cc_library( + name = "beagle_usb_driver_provider", + srcs = ["beagle_usb_driver_provider.cc"], + deps = [ + ":beagle_top_level_handler", + ":beagle_top_level_interrupt_manager", + "@com_google_absl//absl/strings", + "//api:chip", + "//api:driver", + "//api:driver_options_fbs", + "//driver:allocator", + "//driver:driver_factory", + "//driver:package_registry", + "//driver:package_verifier", + "//driver:run_controller", + "//driver:scalar_core_controller", + "//driver/config/beagle:beagle_chip_config", + "//driver/interrupt:grouped_interrupt_controller", + "//driver/interrupt:interrupt_controller", + "//driver/interrupt:interrupt_controller_interface", + "//driver/memory:null_dram_allocator", + "//driver/time_stamper:driver_time_stamper", + "//driver/usb:local_usb_device", + "//driver/usb:usb_device_interface", + "//driver/usb:usb_driver", + "//driver/usb:usb_ml_commands", + "//driver/usb:usb_registers", + "//port", + "//port:tracing", + ], + alwayslink = 1, +) + +# Provides Beagle PCI Driver. 
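+# The provider registers itself with the driver factory via
+# REGISTER_DRIVER_PROVIDER, so alwayslink = 1 is needed to keep that static
+# registration from being dropped when nothing references this library
+# directly.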
+cc_library( + name = "beagle_pci_driver_provider", + srcs = ["beagle_pci_driver_provider.cc"], + deps = [ + ":beagle_kernel_top_level_handler", + "//api:chip", + "//api:driver", + "//driver:allocator", + "//driver:driver_factory", + "//driver:hardware_structures", + "//driver:mmio_driver", + "//driver:package_registry", + "//driver:package_verifier", + "//driver:run_controller", + "//driver:scalar_core_controller", + "//driver/config", + "//driver/config/beagle:beagle_chip_config", + "//driver/interrupt:dummy_interrupt_controller", + "//driver/interrupt:grouped_interrupt_controller", + "//driver/interrupt:interrupt_handler", + "//driver/interrupt:top_level_interrupt_manager", + "//driver/kernel:kernel_coherent_allocator", + "//driver/kernel:kernel_interrupt_handler", + "//driver/kernel:kernel_mmu_mapper", + "//driver/kernel:kernel_registers", + "//driver/kernel:kernel_wire_interrupt_handler", + "//driver/memory:dual_address_space", + "//driver/memory:null_dram_allocator", + "//driver/mmio:host_queue", + "//driver/time_stamper:driver_time_stamper", + "//port", + ], + alwayslink = 1, +) + +# Provides Beagle USB/PCI Driver. Used by an Android side script to +# populate dependencies for a unified Beagle provider in Beagle NNAPI HAL. +cc_library( + name = "beagle_all_driver_provider", + srcs = [ + "beagle_pci_driver_provider.cc", + "beagle_usb_driver_provider.cc", + ], + deps = [ + ":beagle_kernel_top_level_handler", + ":beagle_top_level_handler", + ":beagle_top_level_interrupt_manager", + "@com_google_absl//absl/strings", + "//api:chip", + "//api:driver", + "//api:driver_options_fbs", + "//driver:allocator", + "//driver:driver_factory", + "//driver:hardware_structures", + "//driver:mmio_driver", + "//driver:package_registry", + "//driver:package_verifier", + "//driver:run_controller", + "//driver:scalar_core_controller", + "//driver/config", + "//driver/config/beagle:beagle_chip_config", + "//driver/interrupt:dummy_interrupt_controller", + "//driver/interrupt:grouped_interrupt_controller", + "//driver/interrupt:interrupt_controller", + "//driver/interrupt:interrupt_controller_interface", + "//driver/interrupt:interrupt_handler", + "//driver/interrupt:top_level_interrupt_manager", + "//driver/kernel:kernel_coherent_allocator", + "//driver/kernel:kernel_interrupt_handler", + "//driver/kernel:kernel_mmu_mapper", + "//driver/kernel:kernel_registers", + "//driver/kernel:kernel_wire_interrupt_handler", + "//driver/memory:dual_address_space", + "//driver/memory:null_dram_allocator", + "//driver/mmio:host_queue", + "//driver/time_stamper:driver_time_stamper", + "//driver/usb:local_usb_device", + "//driver/usb:usb_device_interface", + "//driver/usb:usb_driver", + "//driver/usb:usb_ml_commands", + "//driver/usb:usb_registers", + "//port", + "//port:tracing", + ], + alwayslink = 1, +) + +cc_library( + name = "beagle_top_level_interrupt_manager", + srcs = ["beagle_top_level_interrupt_manager.cc"], + hdrs = ["beagle_top_level_interrupt_manager.h"], + deps = [ + "//driver/config", + "//driver/interrupt:interrupt_controller_interface", + "//driver/interrupt:top_level_interrupt_manager", + "//driver/registers", + "//port", + ], +) + +cc_library( + name = "beagle_top_level_handler", + srcs = ["beagle_top_level_handler.cc"], + hdrs = ["beagle_top_level_handler.h"], + deps = [ + "//api:driver_options_fbs", + "//driver:top_level_handler", + "//driver/config", + "//driver/registers", + "//port", + ], +) + +cc_library( + name = "beagle_ioctl", + hdrs = ["beagle_ioctl.h"], +) + +cc_library( + name = 
"beagle_kernel_top_level_handler", + srcs = ["beagle_kernel_top_level_handler.cc"], + hdrs = ["beagle_kernel_top_level_handler.h"], + deps = [ + ":beagle_ioctl", + "//api:driver_options_fbs", + "//driver:top_level_handler", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +# Standalone shared library that can be linked into a binary built elsewhere. +cc_binary( + name = "libbeagle-usb.so", + + # Setting DT_SONAME for the result .so. + # See b/72234056 + linkopts = [ + "-Wl,-soname,libbeagle-usb.so", + "-Wl,--version-script=$(location //driver:libdarwinn_driver.lds)", + ], + linkshared = 1, + linkstatic = 1, + deps = [ + ":beagle_usb_driver_provider", + "//driver:libdarwinn_driver.lds", + ], +) + +# Standalone shared library that can be linked into a binary built elsewhere. +cc_binary( + name = "libbeagle-pci.so", + + # Setting DT_SONAME for the result .so. + # See b/72234056 + linkopts = [ + "-Wl,-soname,libbeagle-pci.so", + "-Wl,--version-script=$(location //driver:libdarwinn_driver.lds)", + ], + linkshared = 1, + linkstatic = 1, + deps = [ + ":beagle_pci_driver_provider", + "//driver:libdarwinn_driver.lds", + ], +) diff --git a/driver/beagle/beagle_ioctl.h b/driver/beagle/beagle_ioctl.h new file mode 100644 index 0000000..5f2e44e --- /dev/null +++ b/driver/beagle/beagle_ioctl.h @@ -0,0 +1,61 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Apex kernel-userspace interface definitions. + */ +#ifndef __APEX_IOCTL_H__ +#define __APEX_IOCTL_H__ + +#include +#ifndef __KERNEL__ +#include +#endif + +/* Clock Gating ioctl. */ +struct apex_gate_clock_ioctl { + /* Enter or leave clock gated state. */ + uint64_t enable; + + /* If set, enter clock gating state, regardless of custom block's + * internal idle state. + */ + uint64_t force_idle; +}; + +/* Performance expectation ioctl. */ +enum apex_performance_expectation { + APEX_PERFORMANCE_LOW = 0, + APEX_PERFORMANCE_MED = 1, + APEX_PERFORMANCE_HIGH = 2, + APEX_PERFORMANCE_MAX = 3, +}; + +struct apex_performance_expectation_ioctl { + /* Expected performance from apex. */ + uint32_t performance; +}; + +/* Base number for all Apex-common IOCTLs. */ +#define APEX_IOCTL_BASE 0x7F + +/* Enable/Disable clock gating. */ +#define APEX_IOCTL_GATE_CLOCK \ + _IOW(APEX_IOCTL_BASE, 0, struct apex_gate_clock_ioctl) + +/* Change performance expectation. */ +#define APEX_IOCTL_PERFORMANCE_EXPECTATION \ + _IOW(APEX_IOCTL_BASE, 1, struct apex_performance_expectation_ioctl) + +#endif /* __APEX_IOCTL_H__ */ diff --git a/driver/beagle/beagle_kernel_top_level_handler.cc b/driver/beagle/beagle_kernel_top_level_handler.cc new file mode 100644 index 0000000..fd0df0d --- /dev/null +++ b/driver/beagle/beagle_kernel_top_level_handler.cc @@ -0,0 +1,146 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/beagle/beagle_kernel_top_level_handler.h" + +#include +#include +#include +#include + +#include "api/driver_options_generated.h" +#include "driver/beagle/beagle_ioctl.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +BeagleKernelTopLevelHandler::BeagleKernelTopLevelHandler( + const std::string &device_path, api::PerformanceExpectation performance) + : device_path_(device_path), performance_(performance) {} + +util::Status BeagleKernelTopLevelHandler::DisableSoftwareClockGate() { + StdMutexLock lock(&mutex_); + + if (!clock_gated_) { + return util::Status(); // OK + } + + apex_gate_clock_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + ioctl_buffer.enable = 0; + + if (ioctl(fd_, APEX_IOCTL_GATE_CLOCK, &ioctl_buffer) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Could not Disable Clock Gating : %d (%s)", fd_, strerror(errno))); + } + + clock_gated_ = false; + + return util::Status(); // OK +} + +util::Status BeagleKernelTopLevelHandler::EnableSoftwareClockGate() { + StdMutexLock lock(&mutex_); + + if (clock_gated_) { + return util::Status(); // OK + } + + apex_gate_clock_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + ioctl_buffer.enable = 1; + + if (ioctl(fd_, APEX_IOCTL_GATE_CLOCK, &ioctl_buffer) != 0) { + return util::FailedPreconditionError( + StringPrintf("Could not Clock Gate : %d (%s)", fd_, strerror(errno))); + } + + clock_gated_ = true; + + return util::Status(); // OK +} + +util::Status BeagleKernelTopLevelHandler::Open() { + StdMutexLock lock(&mutex_); + if (fd_ != -1) { + return util::FailedPreconditionError("Device already open."); + } + + fd_ = open(device_path_.c_str(), O_RDWR); + if (fd_ < 0) { + return util::FailedPreconditionError( + StringPrintf("Device open failed : %d (%s)", fd_, strerror(errno))); + } + + return util::Status(); // OK +} + +util::Status BeagleKernelTopLevelHandler::Close() { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + close(fd_); + fd_ = -1; + + return util::Status(); // OK +} + +util::Status BeagleKernelTopLevelHandler::QuitReset() { + apex_performance_expectation_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + + switch (performance_) { + case api::PerformanceExpectation_Low: + ioctl_buffer.performance = APEX_PERFORMANCE_LOW; + break; + + case api::PerformanceExpectation_Medium: + ioctl_buffer.performance = APEX_PERFORMANCE_MED; + break; + + case api::PerformanceExpectation_High: + ioctl_buffer.performance = APEX_PERFORMANCE_HIGH; + break; + + case api::PerformanceExpectation_Max: + ioctl_buffer.performance = APEX_PERFORMANCE_MAX; + break; + + default: + return util::InvalidArgumentError( + StringPrintf("Bad performance setting %d.", performance_)); + } + + StdMutexLock lock(&mutex_); + if (ioctl(fd_, APEX_IOCTL_PERFORMANCE_EXPECTATION, &ioctl_buffer) != 0) { + 
LOG(WARNING) << StringPrintf( + "Could not set performance expectation : %d (%s)", fd_, + strerror(errno)); + } + + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/beagle/beagle_kernel_top_level_handler.h b/driver/beagle/beagle_kernel_top_level_handler.h new file mode 100644 index 0000000..085cd01 --- /dev/null +++ b/driver/beagle/beagle_kernel_top_level_handler.h @@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_BEAGLE_BEAGLE_KERNEL_TOP_LEVEL_HANDLER_H_ +#define DARWINN_DRIVER_BEAGLE_BEAGLE_KERNEL_TOP_LEVEL_HANDLER_H_ + +#include // NOLINT + +#include "api/driver_options_generated.h" +#include "driver/top_level_handler.h" +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Handles chip specific resets. +class BeagleKernelTopLevelHandler : public TopLevelHandler { + public: + BeagleKernelTopLevelHandler(const std::string &device_path, + api::PerformanceExpectation performance); + ~BeagleKernelTopLevelHandler() override = default; + + // Implements ResetHandler interface. + util::Status Open() override; + util::Status Close() override; + util::Status EnableSoftwareClockGate() override; + util::Status DisableSoftwareClockGate() override; + util::Status QuitReset() override; + + private: + // Device path. + const std::string device_path_; + + // File descriptor of the opened device. + int fd_ GUARDED_BY(mutex_){-1}; + + // Mutex that guards fd_. + std::mutex mutex_; + + // Chip starts in clock gated state. + bool clock_gated_{true}; + + // Performance setting. + const api::PerformanceExpectation performance_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_BEAGLE_BEAGLE_KERNEL_TOP_LEVEL_HANDLER_H_ diff --git a/driver/beagle/beagle_pci_driver_provider.cc b/driver/beagle/beagle_pci_driver_provider.cc new file mode 100644 index 0000000..a080221 --- /dev/null +++ b/driver/beagle/beagle_pci_driver_provider.cc @@ -0,0 +1,178 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
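The apex ioctls declared in driver/beagle/beagle_ioctl.h above are what BeagleKernelTopLevelHandler issues for clock gating and performance selection. A standalone sketch of the same call pattern follows; the /dev/apex_0 path is a placeholder (the runtime discovers the real node by enumerating the "apex" sysfs entries), and error handling is reduced to perror.

```cpp
#include <fcntl.h>
#include <sys/ioctl.h>
#include <unistd.h>

#include <cstdio>
#include <cstring>

#include "driver/beagle/beagle_ioctl.h"

int main() {
  // Placeholder path; the runtime finds the real node via the "apex" sysfs
  // entries rather than hard-coding it.
  const char* kDevicePath = "/dev/apex_0";
  int fd = open(kDevicePath, O_RDWR);
  if (fd < 0) {
    perror("open");
    return 1;
  }

  // Leave the software clock-gated state (enable = 1 would re-enter it).
  apex_gate_clock_ioctl gate;
  memset(&gate, 0, sizeof(gate));
  gate.enable = 0;
  if (ioctl(fd, APEX_IOCTL_GATE_CLOCK, &gate) != 0) perror("gate clock");

  // Request the highest performance expectation.
  apex_performance_expectation_ioctl perf;
  memset(&perf, 0, sizeof(perf));
  perf.performance = APEX_PERFORMANCE_MAX;
  if (ioctl(fd, APEX_IOCTL_PERFORMANCE_EXPECTATION, &perf) != 0) {
    perror("performance expectation");
  }

  close(fd);
  return 0;
}
```

This mirrors what BeagleKernelTopLevelHandler::DisableSoftwareClockGate and QuitReset do, minus the locking and status plumbing.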
+ +#include +#include +#include +#include + +#include "api/chip.h" +#include "api/driver.h" +#include "driver/aligned_allocator.h" +#include "driver/beagle/beagle_kernel_top_level_handler.h" +#include "driver/config/beagle/beagle_chip_config.h" +#include "driver/config/chip_structures.h" +#include "driver/driver_factory.h" +#include "driver/hardware_structures.h" +#include "driver/interrupt/dummy_interrupt_controller.h" +#include "driver/interrupt/grouped_interrupt_controller.h" +#include "driver/interrupt/interrupt_handler.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/kernel/kernel_coherent_allocator.h" +#include "driver/kernel/kernel_interrupt_handler.h" +#include "driver/kernel/kernel_mmu_mapper.h" +#include "driver/kernel/kernel_registers.h" +#include "driver/kernel/kernel_wire_interrupt_handler.h" +#include "driver/memory/dual_address_space.h" +#include "driver/memory/null_dram_allocator.h" +#include "driver/mmio/host_queue.h" +#include "driver/mmio_driver.h" +#include "driver/package_registry.h" +#include "driver/package_verifier.h" +#include "driver/run_controller.h" +#include "driver/scalar_core_controller.h" +#include "driver/time_stamper/driver_time_stamper.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +using platforms::darwinn::api::Chip; +using platforms::darwinn::api::Device; + +} // namespace + +class BeaglePciDriverProvider : public DriverProvider { + public: + static std::unique_ptr CreateDriverProvider() { + return gtl::WrapUnique(new BeaglePciDriverProvider()); + } + + ~BeaglePciDriverProvider() override = default; + + std::vector Enumerate() override; + bool CanCreate(const Device& device) override; + util::StatusOr> CreateDriver( + const Device& device, const api::DriverOptions& options) override; + + private: + BeaglePciDriverProvider() = default; +}; + +REGISTER_DRIVER_PROVIDER(BeaglePciDriverProvider); + +std::vector BeaglePciDriverProvider::Enumerate() { + return EnumerateSysfs("apex", Chip::kBeagle, Device::Type::PCI); +} + +bool BeaglePciDriverProvider::CanCreate(const Device& device) { + return device.type == Device::Type::PCI && device.chip == Chip::kBeagle; +} + +util::StatusOr> +BeaglePciDriverProvider::CreateDriver(const Device& device, + const api::DriverOptions& options) { + if (!CanCreate(device)) { + return util::NotFoundError("Unsupported device."); + } + + // TODO: Following queue size could come from a config. + constexpr int kInstructionQueueSize = 256; + + // Coherent memory block granted to the Host Queue + constexpr int kCoherentAllocatorMaxSizeByte = 0x4000; + + auto config = gtl::MakeUnique(); + + // Offsets are embedded in the CSR spec. + constexpr uint64 kTileConfig0Offset = 0x40000; + constexpr uint64 kScalarCoreOffset = 0x44000; + constexpr uint64 kUserHibOffset = 0x48000; + + // Memory mapping must be aligned with page size. Assuming 4KB page size. 
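+ // Each region mapped below therefore spans exactly one 4 KiB page starting
+ // at its CSR offset.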
+ constexpr uint64 kSectionSize = 0x1000; + + const std::vector regions = { + {kTileConfig0Offset, kSectionSize}, + {kScalarCoreOffset, kSectionSize}, + {kUserHibOffset, kSectionSize}, + }; + auto registers = gtl::MakeUnique(device.path, regions, + /*read_only=*/false); + + auto interrupt_handler = gtl::MakeUnique(device.path); + auto top_level_handler = gtl::MakeUnique( + device.path, options.performance_expectation()); + auto mmu_mapper = gtl::MakeUnique(device.path); + auto address_space = gtl::MakeUnique( + config->GetChipStructures(), mmu_mapper.get()); + int allocation_alignment_bytes = + config->GetChipStructures().allocation_alignment_bytes; + auto allocator = + gtl::MakeUnique(allocation_alignment_bytes); + auto coherent_allocator = gtl::MakeUnique( + device.path, allocation_alignment_bytes, kCoherentAllocatorMaxSizeByte); + auto host_queue = + gtl::MakeUnique>( + config->GetInstructionQueueCsrOffsets(), config->GetChipStructures(), + registers.get(), std::move(coherent_allocator), + kInstructionQueueSize, /*single_descriptor_mode=*/false); + + // Keeping the number of interrupt so MmioDriver would still register for four + // interrupt handlers. + constexpr int kNumTopLevelInterrupts = 4; + auto top_level_interrupt_controller = + gtl::MakeUnique(kNumTopLevelInterrupts); + + // TODO Bridge top level interrupts to higher level logic. + // TopLevelInterruptManager initialized with DummyInterruptController leaves + // top level interrupts not really handled. We will have to further + // extend TopLevelInterruptManager to bridge top level interrupt to + // application/driver logic. + auto top_level_interrupt_manager = gtl::MakeUnique( + std::move(top_level_interrupt_controller)); + + auto fatal_error_interrupt_controller = gtl::MakeUnique( + config->GetFatalErrorInterruptCsrOffsets(), registers.get()); + auto scalar_core_controller = + gtl::MakeUnique(*config, registers.get()); + auto run_controller = + gtl::MakeUnique(*config, registers.get()); + + auto dram_allocator = gtl::MakeUnique(); + + ASSIGN_OR_RETURN( + auto verifier, + MakeExecutableVerifier(flatbuffers::GetString(options.public_key()))); + auto executable_registry = gtl::MakeUnique( + device.chip, std::move(verifier), dram_allocator.get()); + auto time_stamper = gtl::MakeUnique(); + + return {gtl::MakeUnique( + options, std::move(config), std::move(registers), + std::move(dram_allocator), std::move(mmu_mapper), + std::move(address_space), std::move(allocator), std::move(host_queue), + std::move(interrupt_handler), std::move(top_level_interrupt_manager), + std::move(fatal_error_interrupt_controller), + std::move(scalar_core_controller), std::move(run_controller), + std::move(top_level_handler), std::move(executable_registry), + std::move(time_stamper))}; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/beagle/beagle_top_level_handler.cc b/driver/beagle/beagle_top_level_handler.cc new file mode 100644 index 0000000..6890d2a --- /dev/null +++ b/driver/beagle/beagle_top_level_handler.cc @@ -0,0 +1,265 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/beagle/beagle_top_level_handler.h" + +#include "driver/config/beagle_csr_helper.h" +#include "driver/config/common_csr_helper.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { + +using config::registers::ScuCtrl3; + +} // namespace + +BeagleTopLevelHandler::BeagleTopLevelHandler( + const config::ChipConfig& config, Registers* registers, bool use_usb, + api::PerformanceExpectation performance) + : cb_bridge_offsets_(config.GetCbBridgeCsrOffsets()), + hib_user_offsets_(config.GetHibUserCsrOffsets()), + misc_offsets_(config.GetMiscCsrOffsets()), + reset_offsets_(config.GetScuCsrOffsets()), + scalar_core_offsets_(config.GetScalarCoreCsrOffsets()), + tile_config_offsets_(config.GetTileConfigCsrOffsets()), + tile_offsets_(config.GetTileCsrOffsets()), + registers_(registers), + performance_(performance), + use_usb_(use_usb) { + CHECK(registers != nullptr); +} + +util::Status BeagleTopLevelHandler::Open() { + // By reading top level registers, figure out whether chip is in clock gated + // mode. + software_clock_gated_ = false; + hardware_clock_gated_ = false; + + // 1. Always disable inactive mode. + // Read the register to preserve other fields. + ASSIGN_OR_RETURN(uint32 scu_ctrl_0_reg, + registers_->Read32(reset_offsets_.scu_ctrl_0)); + config::registers::ScuCtrl0 helper(scu_ctrl_0_reg); + helper.set_rg_pcie_inact_phy_mode(0); + helper.set_rg_usb_inact_phy_mode(0); + RETURN_IF_ERROR(registers_->Write32(reset_offsets_.scu_ctrl_0, helper.raw())); + + // 2. Check "rg_gated_gcb". + // 0x0: deprecated + // 0x1: hardware clock gated + // 0x2: no clock gating + ASSIGN_OR_RETURN(uint32 scu_ctrl_2_reg, + registers_->Read32(reset_offsets_.scu_ctrl_2)); + config::registers::ScuCtrl2 scu_ctrl_2(scu_ctrl_2_reg); + if (scu_ctrl_2.rg_gated_gcb() == 0x1) { + hardware_clock_gated_ = true; + } + + return util::Status(); // OK +} + +util::Status BeagleTopLevelHandler::QuitReset() { + // Disable Sleep Mode (Partial Software Control) + // 1. Make "rg_force_sleep" to be b10. Read the register to preserve other + // fields. + // 2. Set GCB, AXI, and 8051 clock rate according to desired performance + // level. 
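+  // For reference, the switch below resolves to: GCB at 63/125/250/500 MHz
+  // for Low/Medium/High/Max respectively; AXI at 125 MHz everywhere except
+  // Max over USB (250 MHz); and the 8051 raised to 500 MHz only when USB is
+  // in use at Medium or above, otherwise 250 MHz.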
+ ASSIGN_OR_RETURN(uint32 scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + config::registers::ScuCtrl3 helper(scu_ctrl_3_reg); + helper.set_rg_force_sleep(0b10); + + switch (performance_) { + case api::PerformanceExpectation_Low: + helper.set_gcb_clock_rate(ScuCtrl3::GcbClock::k63MHZ); + helper.set_axi_clock_rate(ScuCtrl3::AxiClock::k125MHZ); + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k250MHZ); + break; + + case api::PerformanceExpectation_Medium: + helper.set_gcb_clock_rate(ScuCtrl3::GcbClock::k125MHZ); + helper.set_axi_clock_rate(ScuCtrl3::AxiClock::k125MHZ); + if (use_usb_) { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k500MHZ); + } else { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k250MHZ); + } + break; + + case api::PerformanceExpectation_High: + helper.set_gcb_clock_rate(ScuCtrl3::GcbClock::k250MHZ); + helper.set_axi_clock_rate(ScuCtrl3::AxiClock::k125MHZ); + if (use_usb_) { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k500MHZ); + } else { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k250MHZ); + } + break; + + case api::PerformanceExpectation_Max: + helper.set_gcb_clock_rate(ScuCtrl3::GcbClock::k500MHZ); + if (use_usb_) { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k500MHZ); + helper.set_axi_clock_rate(ScuCtrl3::AxiClock::k250MHZ); + } else { + helper.set_usb_8051_clock_rate(ScuCtrl3::Usb8051Clock::k250MHZ); + helper.set_axi_clock_rate(ScuCtrl3::AxiClock::k125MHZ); + } + break; + + default: + return util::InvalidArgumentError( + StringPrintf("Bad performance setting %d.", performance_)); + } + + RETURN_IF_ERROR(registers_->Write32(reset_offsets_.scu_ctrl_3, helper.raw())); + + // 2. Poll until "cur_pwr_state" is 0x0. Other fields might change as well, + // hence "cur_pwr_state" field has to be explicitly checked. + ASSIGN_OR_RETURN(scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + helper.set_raw(scu_ctrl_3_reg); + while (helper.cur_pwr_state() != 0x0) { + ASSIGN_OR_RETURN(scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + helper.set_raw(scu_ctrl_3_reg); + } + + // 3. Confirm that moved out of reset by reading any CSR with known initial + // value. scalar core run control should be zero. + RETURN_IF_ERROR( + registers_->Poll(scalar_core_offsets_.scalarCoreRunControl, 0)); + + // 4. Enable idle register. + config::registers::IdleRegister idle_reg; + idle_reg.set_enable(); + idle_reg.set_counter(1); + RETURN_IF_ERROR( + registers_->Write(misc_offsets_.idleRegister, idle_reg.raw())); + + // 5. Update sleep/wake delay for tiles. toSleepDelay = 2, toWakeDelay = 30. + // Broadcast to tiles. + // TODO: helper uses 7-bits as defined by CSR. Extract bitwidth + // automatically for different chips. + config::registers::TileConfig<7> tile_config_reg; + tile_config_reg.set_broadcast(); + RETURN_IF_ERROR(registers_->Write(tile_config_offsets_.tileconfig0, + tile_config_reg.raw())); + // Wait until tileconfig0 is set correctly. Subsequent writes are going to + // tiles, but hardware does not guarantee correct ordering with previous + // write. 
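+  // The poll below re-reads tileconfig0 until it matches the broadcast word
+  // just written; that read-back is the ordering barrier before the
+  // deep-sleep delays are programmed.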
+ RETURN_IF_ERROR(registers_->Poll(tile_config_offsets_.tileconfig0, + tile_config_reg.raw())); + + config::registers::DeepSleep deep_sleep_reg; + deep_sleep_reg.set_to_sleep_delay(2); + deep_sleep_reg.set_to_wake_delay(30); + RETURN_IF_ERROR( + registers_->Write(tile_offsets_.deepSleep, deep_sleep_reg.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelHandler::EnableReset() { + // If already in reset, skip reset. Otherwise, HIB CSR accesses will not be + // valid. + ASSIGN_OR_RETURN(uint32 scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + config::registers::ScuCtrl3 helper(scu_ctrl_3_reg); + if (helper.rg_force_sleep() == 0x3) { + return util::Status(); // OK + } + + // Enable Sleep Mode (Partial Software Control). + if (!use_usb_) { + // Do Software Force GCB Idle. + // Make sure all outstanding DMAs are drained. Note that USB does not have + // to do step 1/2 as host controls DMAs. + // 1. Enable DMA pause. + RETURN_IF_ERROR(registers_->Write(hib_user_offsets_.dma_pause, 1)); + + // 2. Wait until DMA is paused. + RETURN_IF_ERROR(registers_->Poll(hib_user_offsets_.dma_paused, 1)); + } + + // Actual enable sleep mode. + // 3. Set "rg_force_sleep" to 0x3. Read the register to preserve other fields. + helper.set_rg_force_sleep(0x3); + RETURN_IF_ERROR(registers_->Write32(reset_offsets_.scu_ctrl_3, helper.raw())); + + // 4. Poll until "cur_pwr_state" becomes 0x2. Other fields might change as + // well, hence "cur_pwr_state" field has to be explicitly checked. + ASSIGN_OR_RETURN(scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + helper.set_raw(scu_ctrl_3_reg); + while (helper.cur_pwr_state() != 0x2) { + ASSIGN_OR_RETURN(scu_ctrl_3_reg, + registers_->Read32(reset_offsets_.scu_ctrl_3)); + helper.set_raw(scu_ctrl_3_reg); + } + + // 5. Clear BULK credit by pulsing LSBs of "gcbb_credit0". + RETURN_IF_ERROR(registers_->Write32(cb_bridge_offsets_.gcbb_credit0, 0xF)); + return registers_->Write32(cb_bridge_offsets_.gcbb_credit0, 0x0); +} + +util::Status BeagleTopLevelHandler::EnableHardwareClockGate() { + if (hardware_clock_gated_) { + return util::Status(); // OK + } + + // Enable Hardware Clock Gate (GCB) + // 1. Write "rg_gated_gcb" to 0x1. Read the register to preserve other fields. + ASSIGN_OR_RETURN(uint32 scu_ctrl_2_reg, + registers_->Read32(reset_offsets_.scu_ctrl_2)); + config::registers::ScuCtrl2 scu_ctrl_2(scu_ctrl_2_reg); + scu_ctrl_2.set_rg_gated_gcb(0x1); + RETURN_IF_ERROR( + registers_->Write32(reset_offsets_.scu_ctrl_2, scu_ctrl_2.raw())); + + hardware_clock_gated_ = true; + return util::Status(); // OK +} + +util::Status BeagleTopLevelHandler::DisableHardwareClockGate() { + if (!hardware_clock_gated_) { + return util::Status(); // OK + } + + // Disable Software Clock Gate (GCB) + // 1. Force clock on by writing "rg_gated_gcb" to 0x2. Read the register to + // preserve other fields. 
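+  // rg_gated_gcb encoding (see Open() above): 0x0 deprecated, 0x1 hardware
+  // clock gated, 0x2 no clock gating (clock forced on).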
+ ASSIGN_OR_RETURN(uint32 scu_ctrl_2_reg, + registers_->Read32(reset_offsets_.scu_ctrl_2)); + config::registers::ScuCtrl2 scu_ctrl_2(scu_ctrl_2_reg); + scu_ctrl_2.set_rg_gated_gcb(0x2); + RETURN_IF_ERROR( + registers_->Write32(reset_offsets_.scu_ctrl_2, scu_ctrl_2.raw())); + + hardware_clock_gated_ = false; + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/beagle/beagle_top_level_handler.h b/driver/beagle/beagle_top_level_handler.h new file mode 100644 index 0000000..d2d9e31 --- /dev/null +++ b/driver/beagle/beagle_top_level_handler.h @@ -0,0 +1,78 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_HANDLER_H_ +#define DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_HANDLER_H_ + +#include "api/driver_options_generated.h" +#include "driver/config/cb_bridge_csr_offsets.h" +#include "driver/config/chip_config.h" +#include "driver/config/hib_user_csr_offsets.h" +#include "driver/config/misc_csr_offsets.h" +#include "driver/config/scalar_core_csr_offsets.h" +#include "driver/config/scu_csr_offsets.h" +#include "driver/config/tile_config_csr_offsets.h" +#include "driver/config/tile_csr_offsets.h" +#include "driver/registers/registers.h" +#include "driver/top_level_handler.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Handles beagle resets. Only used in remote driver as this will be handled in +// kernel space in kernel driver. +class BeagleTopLevelHandler : public TopLevelHandler { + public: + BeagleTopLevelHandler(const config::ChipConfig& config, Registers* registers, + bool use_usb, api::PerformanceExpectation performance); + ~BeagleTopLevelHandler() override = default; + + // Implements ResetHandler interface. + util::Status Open() override; + util::Status QuitReset() override; + util::Status EnableReset() override; + util::Status EnableHardwareClockGate() override; + util::Status DisableHardwareClockGate() override; + + private: + // CSR offsets. + const config::CbBridgeCsrOffsets& cb_bridge_offsets_; + const config::HibUserCsrOffsets& hib_user_offsets_; + const config::MiscCsrOffsets& misc_offsets_; + const config::ScuCsrOffsets& reset_offsets_; + const config::ScalarCoreCsrOffsets& scalar_core_offsets_; + const config::TileConfigCsrOffsets& tile_config_offsets_; + const config::TileCsrOffsets& tile_offsets_; + + // CSR interface. + Registers* const registers_; + + // Select clock combinations for performance. + const api::PerformanceExpectation performance_; + + // Whether USB is used for Beagle. + const bool use_usb_; + + // True if clock gated. Starts at non-clock gated mode. 
+ bool software_clock_gated_{false}; + bool hardware_clock_gated_{false}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_HANDLER_H_ diff --git a/driver/beagle/beagle_top_level_interrupt_manager.cc b/driver/beagle/beagle_top_level_interrupt_manager.cc new file mode 100644 index 0000000..57ca521 --- /dev/null +++ b/driver/beagle/beagle_top_level_interrupt_manager.cc @@ -0,0 +1,359 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/beagle/beagle_top_level_interrupt_manager.h" + +#include + +#include "driver/config/beagle_csr_helper.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// Top Level Interrupt ids: +// https://g3doc.corp.google.com/platforms/darwinn/silo/g3doc/spec/index.md#interrupt-handling +constexpr int kThermalShutdownId = 0; +constexpr int kPcieErrorId = 1; +constexpr int kMbistId = 2; +constexpr int kThermalWarningId = 3; + +} // namespace + +BeagleTopLevelInterruptManager::BeagleTopLevelInterruptManager( + std::unique_ptr interrupt_controller, + const config::ChipConfig& config, Registers* registers) + : TopLevelInterruptManager(std::move(interrupt_controller)), + apex_csr_offsets_(config.GetApexCsrOffsets()), + scu_csr_offsets_(config.GetScuCsrOffsets()), + registers_(registers) { + CHECK(registers != nullptr); +} + +util::Status BeagleTopLevelInterruptManager::DoEnableInterrupts() { + RETURN_IF_ERROR(EnableThermalWarningInterrupt()); + RETURN_IF_ERROR(EnableMbistInterrupt()); + RETURN_IF_ERROR(EnablePcieErrorInterrupt()); + RETURN_IF_ERROR(EnableThermalShutdownInterrupt()); + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DoDisableInterrupts() { + RETURN_IF_ERROR(DisableThermalWarningInterrupt()); + RETURN_IF_ERROR(DisableMbistInterrupt()); + RETURN_IF_ERROR(DisablePcieErrorInterrupt()); + RETURN_IF_ERROR(DisableThermalShutdownInterrupt()); + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DoHandleInterrupt(int id) { + switch (id) { + case kThermalWarningId: + return HandleThermalWarningInterrupt(); + + case kMbistId: + return HandleMbistInterrupt(); + + case kPcieErrorId: + return HandlePcieErrorInterrupt(); + + case kThermalShutdownId: + return HandleThermalShutdownInterrupt(); + + default: + return util::InvalidArgumentError( + StringPrintf("Unknown top level id: %d", id)); + } +} + +util::Status BeagleTopLevelInterruptManager::EnableThermalWarningInterrupt() { + // 1. Enable thermal warning through omc0_d4. + // Read register to preserve other fields. 
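+  // This is a read-modify-write: only thm_warn_en changes, and every other
+  // field of omc0_d4 is written back with the value just read.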
+ ASSIGN_OR_RETURN(const uint32 omc0_d4_read, + registers_->Read32(apex_csr_offsets_.omc0_d4)); + driver::config::registers::Omc0D4 omc0_d4_helper(omc0_d4_read); + omc0_d4_helper.set_thm_warn_en(1); + + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.omc0_d4, omc0_d4_helper.raw())); + + // 2. Set thermal warning threshold temperature. + // TODO: This is important in the real chip, but not for DV purposes + // now. + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::EnableMbistInterrupt() { + // 1. Unmask interrupts, and clear interrupt status. + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 rambist_ctrl_1_read, + registers_->Read32(apex_csr_offsets_.rambist_ctrl_1)); + driver::config::registers::RamBistCtrl1 ram_bist_ctrl_1_helper( + rambist_ctrl_1_read); + // rg_mbist_int_status is write 1 to clear. Set it to 0 not to clear it. + ram_bist_ctrl_1_helper.set_rg_mbist_int_status(0); + ram_bist_ctrl_1_helper.set_rg_mbist_int_mask(0); + + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.rambist_ctrl_1, + ram_bist_ctrl_1_helper.raw())); + + // 2. Unmask interrupts, and clear interrupt status in scu_ctr_7. + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 scu_ctr_7_read, + registers_->Read32(scu_csr_offsets_.scu_ctr_7)); + driver::config::registers::ScuCtrl7 scu7_helper(scu_ctr_7_read); + // pll_lock_failure, and usb_sel_failure are write 1 to clear. Set them to 0 + // not to clear them. + scu7_helper.set_pll_lock_failure(0); + scu7_helper.set_usb_sel_failure(0); + scu7_helper.set_rg_boot_failure_mask(0); + + RETURN_IF_ERROR( + registers_->Write32(scu_csr_offsets_.scu_ctr_7, scu7_helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::EnablePcieErrorInterrupt() { + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 1)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 1)); + // Write 0x3 to unmask. + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.slv_err_resp_isr_mask, 0x3)); + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.mst_err_resp_isr_mask, 0x3)); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::EnableThermalShutdownInterrupt() { + // 1. Enable thermal shutdown through omc0_d8. + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 omc0_d8_read, + registers_->Read32(apex_csr_offsets_.omc0_d8)); + driver::config::registers::Omc0D8 omc0_d8_helper(omc0_d8_read); + omc0_d8_helper.set_sd_en(1); + + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.omc0_d8, omc0_d8_helper.raw())); + + // 2. Set thermal shutdown threshold temperature. + // TODO: This is important in the real chip, but not for DV purposes + // now. + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DisableThermalWarningInterrupt() { + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 omc0_d4_read, + registers_->Read32(apex_csr_offsets_.omc0_d4)); + driver::config::registers::Omc0D4 omc0_d4_helper(omc0_d4_read); + omc0_d4_helper.set_thm_warn_en(0); + + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.omc0_d4, omc0_d4_helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DisableMbistInterrupt() { + // Read register to preserve other fields. 
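+  // Writing 0x7 to rg_mbist_int_mask below masks all three MBIST sources
+  // (fail, timeout and finish; see HandleMbistInterrupt() for the bit
+  // assignments).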
+ ASSIGN_OR_RETURN(const uint32 rambist_ctrl_1_read, + registers_->Read32(apex_csr_offsets_.rambist_ctrl_1)); + driver::config::registers::RamBistCtrl1 ram_bist_ctrl_1_helper( + rambist_ctrl_1_read); + ram_bist_ctrl_1_helper.set_rg_mbist_int_mask(0x7); + + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.rambist_ctrl_1, + ram_bist_ctrl_1_helper.raw())); + + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 scu_ctr_7_read, + registers_->Read32(scu_csr_offsets_.scu_ctr_7)); + driver::config::registers::ScuCtrl7 scu7_helper(scu_ctr_7_read); + scu7_helper.set_rg_boot_failure_mask(0x3); + + RETURN_IF_ERROR( + registers_->Write32(scu_csr_offsets_.scu_ctr_7, scu7_helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DisablePcieErrorInterrupt() { + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 0)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 0)); + // Write 0x0 to mask. + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.slv_err_resp_isr_mask, 0)); + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.mst_err_resp_isr_mask, 0)); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::DisableThermalShutdownInterrupt() { + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 omc0_d8_read, + registers_->Read32(apex_csr_offsets_.omc0_d8)); + driver::config::registers::Omc0D8 omc0_d8_helper(omc0_d8_read); + omc0_d8_helper.set_sd_en(0); + + RETURN_IF_ERROR( + registers_->Write32(apex_csr_offsets_.omc0_d8, omc0_d8_helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::HandleThermalWarningInterrupt() { + // Read register to preserve other fields. Also, warn_o field needs to be read + // before clearing warn_clear, i.e. read before write field. + ASSIGN_OR_RETURN(const uint32 omc0_dc_read, + registers_->Read32(apex_csr_offsets_.omc0_dc)); + driver::config::registers::Omc0DC helper(omc0_dc_read); + + // Unconditionally clears interrupts. Proper interrupt management has to + // handle the thermal warning and wait for temperature to go down below + // threshold before re-enabling. + if (helper.warn_o()) { + VLOG(5) << "Thermal warning interrupt received"; + helper.set_warn_clear(1); // Writes 1 to clear. + } + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.omc0_dc, helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::HandleMbistInterrupt() { + ASSIGN_OR_RETURN(const uint32 rambist_ctrl_1_read, + registers_->Read32(apex_csr_offsets_.rambist_ctrl_1)); + driver::config::registers::RamBistCtrl1 ram_bist_ctrl_1_helper( + rambist_ctrl_1_read); + + // Proper interrupt management is required in the real chip. For DV, just + // print whether we received the correct interrupt. 
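+  // rg_mbist_int_status is write-1-to-clear, so the bits observed below are
+  // accumulated in status_value and written back to acknowledge exactly the
+  // sources that fired.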
+ uint64 status_value = 0x0; + constexpr uint64 kMbistFail = 0x1; + if ((ram_bist_ctrl_1_helper.rg_mbist_int_status() & kMbistFail) == + kMbistFail) { + VLOG(5) << "Mbist fail interrupt received"; + status_value |= kMbistFail; + } + + constexpr uint64 kMbistTimeout = 0x2; + if ((ram_bist_ctrl_1_helper.rg_mbist_int_status() & kMbistTimeout) == + kMbistTimeout) { + VLOG(5) << "Mbist timeout interrupt received"; + status_value |= kMbistTimeout; + } + + constexpr uint64 kMbistFinish = 0x4; + if ((ram_bist_ctrl_1_helper.rg_mbist_int_status() & kMbistFinish) == + kMbistFinish) { + VLOG(5) << "Mbist finish interrupt received"; + status_value |= kMbistFinish; + } + + ram_bist_ctrl_1_helper.set_rg_mbist_int_status(status_value); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.rambist_ctrl_1, + ram_bist_ctrl_1_helper.raw())); + + // Read register to preserve other fields. + ASSIGN_OR_RETURN(const uint32 scu_ctr_7_read, + registers_->Read32(scu_csr_offsets_.scu_ctr_7)); + driver::config::registers::ScuCtrl7 scu7_helper(scu_ctr_7_read); + + // Proper interrupt management is required in the real chip. For DV, just + // print whether we received the correct interrupt. + if (scu7_helper.usb_sel_failure()) { + VLOG(5) << "bt_usb_sel violates the eFuse interrupt received"; + scu7_helper.set_usb_sel_failure(1); + } + if (scu7_helper.pll_lock_failure()) { + VLOG(5) << "PLL lock timeout interrupt received"; + scu7_helper.set_pll_lock_failure(1); + } + + RETURN_IF_ERROR( + registers_->Write32(scu_csr_offsets_.scu_ctr_7, scu7_helper.raw())); + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::HandlePcieErrorInterrupt() { + // Disable and enable abm_en to handle interrupts. + ASSIGN_OR_RETURN(const uint32 slave_write_error, + registers_->Read32(apex_csr_offsets_.slv_wr_err_resp)); + if (slave_write_error == 1) { + VLOG(5) << "Slave write interrupt received"; + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 0)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 1)); + } + ASSIGN_OR_RETURN(const uint32 slave_read_error, + registers_->Read32(apex_csr_offsets_.slv_rd_err_resp)); + if (slave_read_error == 1) { + VLOG(5) << "Slave read interrupt received"; + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 0)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.slv_abm_en, 1)); + } + + ASSIGN_OR_RETURN(const uint32 master_write_error, + registers_->Read32(apex_csr_offsets_.mst_wr_err_resp)); + if (master_write_error == 1) { + VLOG(5) << "Master write interrupt received"; + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 0)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 1)); + } + ASSIGN_OR_RETURN(const uint32 master_read_error, + registers_->Read32(apex_csr_offsets_.mst_rd_err_resp)); + if (master_read_error == 1) { + VLOG(5) << "Master read interrupt received"; + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 0)); + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.mst_abm_en, 1)); + } + + return util::Status(); // OK +} + +util::Status BeagleTopLevelInterruptManager::HandleThermalShutdownInterrupt() { + // Read register to preserve other fields. Also, sd_o field needs to be read + // before clearing sd_clear, i.e. read before write field. + ASSIGN_OR_RETURN(const uint32 omc0_dc_read, + registers_->Read32(apex_csr_offsets_.omc0_dc)); + driver::config::registers::Omc0DC helper(omc0_dc_read); + + // Unconditionally clears interrupts. 
Proper interrupt management has to + // handle the thermal shutdown, and wait for temperature to go down below + // threshold before re-enabling. + if (helper.sd_o()) { + VLOG(5) << "Thermal shutdown interrupt received"; + helper.set_sd_clear(1); // Writes 1 clear. + } + RETURN_IF_ERROR(registers_->Write32(apex_csr_offsets_.omc0_dc, helper.raw())); + + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/beagle/beagle_top_level_interrupt_manager.h b/driver/beagle/beagle_top_level_interrupt_manager.h new file mode 100644 index 0000000..dddf9da --- /dev/null +++ b/driver/beagle/beagle_top_level_interrupt_manager.h @@ -0,0 +1,79 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_INTERRUPT_MANAGER_H_ +#define DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_INTERRUPT_MANAGER_H_ + +#include + +#include "driver/config/apex_csr_offsets.h" +#include "driver/config/chip_config.h" +#include "driver/config/scu_csr_offsets.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/registers/registers.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Beagle-specific top level interrupt management. +class BeagleTopLevelInterruptManager : public TopLevelInterruptManager { + public: + BeagleTopLevelInterruptManager( + std::unique_ptr interrupt_controller, + const config::ChipConfig& config, Registers* registers); + ~BeagleTopLevelInterruptManager() override = default; + + protected: + // Implements interfaces. + util::Status DoEnableInterrupts() override; + util::Status DoDisableInterrupts() override; + util::Status DoHandleInterrupt(int id) override; + + private: + // Implements extra CSR managements to enable top level interrupts. + util::Status EnableThermalWarningInterrupt(); + util::Status EnableMbistInterrupt(); + util::Status EnablePcieErrorInterrupt(); + util::Status EnableThermalShutdownInterrupt(); + + // Implements extra CSR managements to disable top level interrupts. + util::Status DisableThermalWarningInterrupt(); + util::Status DisableMbistInterrupt(); + util::Status DisablePcieErrorInterrupt(); + util::Status DisableThermalShutdownInterrupt(); + + // Implements top level interrupt handling. + util::Status HandleThermalWarningInterrupt(); + util::Status HandleMbistInterrupt(); + util::Status HandlePcieErrorInterrupt(); + util::Status HandleThermalShutdownInterrupt(); + + // Apex CSR offsets. + const config::ApexCsrOffsets& apex_csr_offsets_; + + // SCU CSR offsets. + const config::ScuCsrOffsets scu_csr_offsets_; + + // CSR interface. 
+ Registers* const registers_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_BEAGLE_BEAGLE_TOP_LEVEL_INTERRUPT_MANAGER_H_ diff --git a/driver/beagle/beagle_usb_driver_provider.cc b/driver/beagle/beagle_usb_driver_provider.cc new file mode 100644 index 0000000..3e8ad1a --- /dev/null +++ b/driver/beagle/beagle_usb_driver_provider.cc @@ -0,0 +1,374 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#if !defined(DARWINN_PORT_ANDROID_SYSTEM) && \ + !defined(DARWINN_PORT_ANDROID_EMULATOR) +#include "absl/strings/numbers.h" +#endif + +#include "api/chip.h" +#include "api/driver.h" +#include "api/driver_options_generated.h" +#include "driver/aligned_allocator.h" +#include "driver/beagle/beagle_top_level_handler.h" +#include "driver/beagle/beagle_top_level_interrupt_manager.h" +#include "driver/config/beagle/beagle_chip_config.h" +#include "driver/driver_factory.h" +#include "driver/interrupt/grouped_interrupt_controller.h" +#include "driver/interrupt/interrupt_controller.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/memory/null_dram_allocator.h" +#include "driver/package_registry.h" +#include "driver/package_verifier.h" +#include "driver/run_controller.h" +#include "driver/scalar_core_controller.h" +#include "driver/time_stamper/driver_time_stamper.h" +#include "driver/usb/local_usb_device.h" +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_driver.h" +#include "driver/usb/usb_ml_commands.h" +#include "driver/usb/usb_registers.h" +#include "port/gflags.h" +#include "port/ptr_util.h" +#include "port/tracing.h" + +namespace { +bool GetEnv(const char* env_var, bool default_value) { +#if !defined(DARWINN_PORT_ANDROID_SYSTEM) && \ + !defined(DARWINN_PORT_ANDROID_EMULATOR) + bool value; + const char* value_str = std::getenv(env_var); + if (value_str != nullptr && absl::SimpleAtob(value_str, &value)) return value; +#endif + return default_value; +} + +int GetEnv(const char* env_var, int default_value) { +#if !defined(DARWINN_PORT_ANDROID_SYSTEM) && \ + !defined(DARWINN_PORT_ANDROID_EMULATOR) + int value; + const char* value_str = std::getenv(env_var); + if (value_str != nullptr && absl::SimpleAtoi(value_str, &value)) return value; +#endif + return default_value; +} + +#if defined(__APPLE__) +constexpr int kDefaultUsbMaxNumAsyncTransfers = 1; +#else +constexpr int kDefaultUsbMaxNumAsyncTransfers = 3; +#endif +} // namespace + +/* + * There are only 3 modes of operation regarding + * usb_enable_bulk_out_descriptors_from_device and + * usb_enable_processing_of_hints: + * + * 1) both true, we follow the hints, and + * use descriptors sent from device as validation. This mode doesn't work if + * the device sends a lot of bulk-out or bulk-in descriptors out which could + * clog the descriptor/bulk-in pipeline. + * + * 2) disable descriptors but enable hints. 
We blindly follow the hints and + * send data to device as fast as we can. The mode is similar to the + * previous one, but could be slightly faster. + * + * 3) enable descriptors but disable the hints. we use descriptors from + * device and pretend there is no hint from code gen, except for first one + * (for instructions). This mode doesn't work with multiple instruction + * chunks, as device is not capable of generating descriptors for + * instructions. + * + */ +ABSL_FLAG(bool, usb_enable_bulk_descriptors_from_device, + GetEnv("USB_ENABLE_BULK_DESCRIPTORS_FROM_DEVICE", false), + "USB set to true if bulk in/out descriptors from device are needed."); +ABSL_FLAG(bool, usb_enable_processing_of_hints, + GetEnv("USB_ENABLE_PROCESSING_OF_HINTS", true), + "USB set to true for driver to proactively send data to device."); +ABSL_FLAG(int, usb_timeout_millis, GetEnv("USB_TIMEOUT_MILLIS", 6000), + "USB timeout in milliseconds"); +ABSL_FLAG(bool, usb_reset_back_to_dfu_mode, + GetEnv("USB_RESET_BACK_TO_DFU_MODE", false), + "USB find device in app mode, reset back to DFU mode, and terminate"); +ABSL_FLAG( + int, usb_software_credits_low_limit, + GetEnv("USB_SOFTWARE_CREDITS_LOW_LIMIT", 8192), + "USB lower bound of bulk out transfer size in bytes, when used in mode 1"); +ABSL_FLAG(int, usb_operating_mode, GetEnv("USB_OPERATING_MODE", 2), + "USB driver operating mode: 0:Multiple-Ep w/ HW, 1:Multiple-Ep w/ " + "SW, 2:Single-Ep"); +ABSL_FLAG(int, usb_max_bulk_out_transfer, + GetEnv("USB_MAX_BULK_OUT_TRANSFER", 1024 * 1024), + "USB max bulk out transfer size in bytes"); +ABSL_FLAG(int, usb_max_num_async_transfers, + GetEnv("USB_MAX_NUM_ASYNC_TRANSFERS", + kDefaultUsbMaxNumAsyncTransfers), + "USB max number of pending async bulk out transfer"); +ABSL_FLAG( + bool, usb_force_largest_bulk_in_chunk_size, + GetEnv("USB_FORCE_LARGEST_BULK_IN_CHUNK_SIZE", false), + "If true, bulk-in data is transmitted in largest chunks possible. Setting " + "this to true increase performance on USB2."); +ABSL_FLAG(bool, usb_enable_overlapping_requests, + GetEnv("USB_ENABLE_OVERLAPPING_REQUESTS", true), + "Allows the next queued request to be partially overlapped with " + "the current one."); +ABSL_FLAG(bool, usb_enable_overlapping_bulk_in_and_out, + GetEnv("USB_ENABLE_OVERLAPPING_BULK_IN_AND_OUT", true), + "Allows bulk-in trasnfer to be submitted before previous bulk-out " + "requests complete."); +ABSL_FLAG(bool, usb_enable_queued_bulk_in_requests, + GetEnv("USB_ENABLE_QUEUED_BULK_IN_REQUESTS", true), + "Allows bulk-in transfers to be queued to improve performance."); +ABSL_FLAG( + bool, usb_fail_if_slower_than_superspeed, + GetEnv("USB_FAIL_IF_SLOWER_THAN_SUPERSPEED", false), + "USB driver open would fail if the connection is slower than superspeed."); +ABSL_FLAG(int, usb_bulk_in_queue_capacity, + GetEnv("USB_BULK_IN_QUEUE_CAPACITY", 32), + "Max number of USB bulk-in requests that can be queued. This " + "option is only effective when it is positive."); + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +using platforms::darwinn::api::Chip; +using platforms::darwinn::api::Device; + +constexpr uint16_t kTargetAppVendorId = 0x18D1; +constexpr uint16_t kTargetAppProductId = 0x9302; + +constexpr uint16_t kTargetDfuVendorId = 0x1A6E; +constexpr uint16_t kTargetDfuProductId = 0x089A; + +// TODO add proper error handling to this function. +// Convenience function to read a file to a vector. 
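+// The bytes returned here are stored in UsbDriverOptions::usb_firmware_image
+// (see CreateDriver() below) and are used when the device is taken through
+// DFU.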
+std::vector ReadToVector(const std::string& file_name) { + VLOG(10) << __func__ << file_name; + + // TODO directly read into the vector instead of transcopying through + // a string. + std::ifstream ifs(file_name); + std::string content_string((std::istreambuf_iterator(ifs)), + (std::istreambuf_iterator())); + + std::vector result; + auto data = reinterpret_cast(content_string.c_str()); + result.insert(result.end(), data, data + content_string.size()); + return result; +} +} // namespace + +class BeagleUsbDriverProvider : public DriverProvider { + public: + static std::unique_ptr CreateDriverProvider() { + return gtl::WrapUnique(new BeagleUsbDriverProvider()); + } + + ~BeagleUsbDriverProvider() override = default; + + std::vector Enumerate() override; + bool CanCreate(const Device& device) override; + util::StatusOr> CreateDriver( + const Device& device, const api::DriverOptions& options) override; + + private: + BeagleUsbDriverProvider() = default; +}; + +REGISTER_DRIVER_PROVIDER(BeagleUsbDriverProvider); + +std::vector BeagleUsbDriverProvider::Enumerate() { + TRACE_SCOPE("BeagleUsbDriverProvider::Enumerate"); + + LocalUsbDeviceFactory usb_device_factory; + std::vector device_list; + + auto usb_dfu_device_list_or_error = usb_device_factory.EnumerateDevices( + kTargetDfuVendorId, kTargetDfuProductId); + + auto usb_app_device_list_or_error = usb_device_factory.EnumerateDevices( + kTargetAppVendorId, kTargetAppProductId); + + if (usb_dfu_device_list_or_error.ok()) { + for (const auto& path : usb_dfu_device_list_or_error.ValueOrDie()) { + device_list.push_back({Chip::kBeagle, Device::Type::USB, path}); + VLOG(10) << StringPrintf("%s: adding path [%s]", __func__, path.c_str()); + } + } + + if (usb_app_device_list_or_error.ok()) { + for (const auto& path : usb_app_device_list_or_error.ValueOrDie()) { + device_list.push_back({Chip::kBeagle, Device::Type::USB, path}); + VLOG(10) << StringPrintf("%s: adding path [%s]", __func__, path.c_str()); + } + } + + return device_list; +} + +bool BeagleUsbDriverProvider::CanCreate(const Device& device) { + return device.type == Device::Type::USB && device.chip == Chip::kBeagle; +} + +util::StatusOr> +BeagleUsbDriverProvider::CreateDriver( + const Device& device, const api::DriverOptions& driver_options) { + TRACE_SCOPE("BeagleUsbDriverProvider::CreateDriver"); + + if (!CanCreate(device)) { + return util::NotFoundError("Unsupported device."); + } + + auto config = gtl::MakeUnique(); + + UsbDriver::UsbDriverOptions options; + options.usb_force_largest_bulk_in_chunk_size = + absl::GetFlag(FLAGS_usb_force_largest_bulk_in_chunk_size); + options.usb_enable_bulk_descriptors_from_device = + absl::GetFlag(FLAGS_usb_enable_bulk_descriptors_from_device); + options.usb_enable_processing_of_hints = + absl::GetFlag(FLAGS_usb_enable_processing_of_hints); + options.usb_max_num_async_transfers = + absl::GetFlag(FLAGS_usb_max_num_async_transfers); + options.mode = static_cast( + absl::GetFlag(FLAGS_usb_operating_mode)); + options.max_bulk_out_transfer_size_in_bytes = + absl::GetFlag(FLAGS_usb_max_bulk_out_transfer); + options.software_credits_lower_limit_in_bytes = + absl::GetFlag(FLAGS_usb_software_credits_low_limit); + options.usb_enable_overlapping_requests = + absl::GetFlag(FLAGS_usb_enable_overlapping_requests); + options.usb_enable_overlapping_bulk_in_and_out = + absl::GetFlag(FLAGS_usb_enable_overlapping_bulk_in_and_out); + options.usb_fail_if_slower_than_superspeed = + absl::GetFlag(FLAGS_usb_fail_if_slower_than_superspeed); + 
options.usb_enable_queued_bulk_in_requests = + absl::GetFlag(FLAGS_usb_enable_queued_bulk_in_requests); + options.usb_bulk_in_queue_capacity = + absl::GetFlag(FLAGS_usb_bulk_in_queue_capacity); + + auto usb_registers = gtl::MakeUnique(); + std::vector> + top_level_interrupt_controllers; + top_level_interrupt_controllers.push_back( + gtl::MakeUnique( + config->GetUsbTopLevel0InterruptCsrOffsets(), usb_registers.get())); + top_level_interrupt_controllers.push_back( + gtl::MakeUnique( + config->GetUsbTopLevel1InterruptCsrOffsets(), usb_registers.get())); + top_level_interrupt_controllers.push_back( + gtl::MakeUnique( + config->GetUsbTopLevel2InterruptCsrOffsets(), usb_registers.get())); + top_level_interrupt_controllers.push_back( + gtl::MakeUnique( + config->GetUsbTopLevel3InterruptCsrOffsets(), usb_registers.get())); + auto top_level_interrupt_controller = + gtl::MakeUnique( + &top_level_interrupt_controllers); + + auto top_level_interrupt_manager = + gtl::MakeUnique( + std::move(top_level_interrupt_controller), *config, + usb_registers.get()); + + auto fatal_error_interrupt_controller = gtl::MakeUnique( + config->GetUsbFatalErrorInterruptCsrOffsets(), usb_registers.get()); + + auto top_level_handler = gtl::MakeUnique( + *config, usb_registers.get(), + /*use_usb=*/true, driver_options.performance_expectation()); + + const api::DriverUsbOptions* usb_options = driver_options.usb(); + if (usb_options != nullptr) { + if (usb_options->dfu_firmware() != nullptr) { + auto provided_dfu_path = usb_options->dfu_firmware()->str(); + if (!provided_dfu_path.empty()) { + // try loading firmware into memory. + options.usb_firmware_image = ReadToVector(provided_dfu_path); + } + } + options.usb_always_dfu = usb_options->always_dfu(); + + // Override command line options if driver options are set. + // Command line options are easier to use for command line tools, but + // most other use cases should set the driver option. + + if (usb_options->has_fail_if_slower_than_superspeed()) { + options.usb_fail_if_slower_than_superspeed = + usb_options->fail_if_slower_than_superspeed(); + } + + if (usb_options->has_force_largest_bulk_in_chunk_size()) { + options.usb_force_largest_bulk_in_chunk_size = + usb_options->force_largest_bulk_in_chunk_size(); + } + + if (usb_options->has_enable_overlapping_bulk_in_and_out()) { + options.usb_enable_overlapping_bulk_in_and_out = + usb_options->enable_overlapping_bulk_in_and_out(); + } + + if (usb_options->has_enable_queued_bulk_in_requests()) { + options.usb_enable_queued_bulk_in_requests = + usb_options->enable_queued_bulk_in_requests(); + } + + if (usb_options->has_bulk_in_queue_capacity()) { + options.usb_bulk_in_queue_capacity = + usb_options->bulk_in_queue_capacity(); + } + } + + auto dram_allocator = gtl::MakeUnique(); + + std::string path(device.path); + ASSIGN_OR_RETURN(auto verifier, MakeExecutableVerifier(flatbuffers::GetString( + driver_options.public_key()))); + auto executable_registry = gtl::MakeUnique( + device.chip, std::move(verifier), dram_allocator.get()); + + auto time_stamper = gtl::MakeUnique(); + + // Note that although driver_options is passed into constructor of UsbDriver, + // it's USB portion is not used by the driver directly, due to historical + // reasons. 
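+  // The USB settings that actually take effect are the UsbDriverOptions
+  // assembled above from the command-line flags and, where present, the
+  // driver_options.usb() overrides.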
+ return {gtl::MakeUnique( + driver_options, std::move(config), + [path] { + LocalUsbDeviceFactory usb_device_factory; + + return usb_device_factory.OpenDevice( + path, absl::GetFlag(FLAGS_usb_timeout_millis)); + }, + std::move(usb_registers), std::move(top_level_interrupt_manager), + std::move(fatal_error_interrupt_controller), std::move(top_level_handler), + std::move(dram_allocator), std::move(executable_registry), options, + std::move(time_stamper))}; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/bitfield.h b/driver/bitfield.h new file mode 100644 index 0000000..7634832 --- /dev/null +++ b/driver/bitfield.h @@ -0,0 +1,104 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Helper class to get/set command status register (CSR) fields. Assumes 64-bit +// CSR registers. +// +// Usage: +// Bitfield<2, 3> field; // Read and write to 0b000XXX00 part of a byte +// +// // Access entire value with raw_value or read/write individual fields. +// // Unused bitfield can be left out. Bitfields are uninitialized because +// // they are expected to be used within union. +// union { +// uint64 raw_value; +// Bitfield<0, 1> enable; +// Bitfield<1, 9> status; +// } reg; +// +// // Explicitly assign default value if exists. +// reg.enable = 1; +// reg.status = 0xF; +// +// LSB_POSITION defines the starting bit of the field specified from the LSB. +// BITS defines the length of the field. Writes to the field that have bits set +// outside of BITS length will cause an error. The values passed in for setting +// and returned from reading will be right aligned (BITS bits starting from the +// LSB). + +#ifndef DARWINN_DRIVER_BITFIELD_H_ +#define DARWINN_DRIVER_BITFIELD_H_ + +#include +#include + +#include + +#include "port/integral_types.h" +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +template +class Bitfield { + public: + // Sets the bitfield to |value|. |value| is right aligned, and should set bits + // in the range of NUM_BITS. + Bitfield& operator=(uint64 value) { + CHECK_EQ(value & kMask, value); + + // Since Bitfield is expected to be used with unions, other bits from value + // must be preserved. + uint64 preserved_bits = value_ & ~(kMask << LSB_POSITION); + value_ = preserved_bits | (value << LSB_POSITION); + return *this; + } + + // Returns the value in a right aligned form. + constexpr uint64 operator()() const { + return (value_ >> LSB_POSITION) & kMask; + } + + // Returns mask for Bitfield. + constexpr uint64 mask() const { + return kMask; + } + + private: + // Supported bits in the underlying value. 
+ static constexpr size_t kMaxBits = sizeof(uint64) * CHAR_BIT; + static_assert(NUM_BITS > 0, "Bitfield must use at least 1 bit"); + static_assert(NUM_BITS <= kMaxBits, + "Bitfield cannot have more bits than 64 bits"); + static_assert(LSB_POSITION < kMaxBits, + "Bitfield cannot start at LSB position higher than 63-bit"); + static_assert(LSB_POSITION + NUM_BITS <= kMaxBits, + "Bitfield cannot have its MSB position past 64-bit"); + + // Any attempt to write outside of kMask will cause an error. + static constexpr uint64 kMask = + (NUM_BITS == kMaxBits) ? (std::numeric_limits::max()) + : (1ULL << NUM_BITS) - 1; + + // Underlying value for Bitfield. + uint64 value_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_BITFIELD_H_ diff --git a/driver/config/BUILD b/driver/config/BUILD new file mode 100644 index 0000000..c2e79a6 --- /dev/null +++ b/driver/config/BUILD @@ -0,0 +1,92 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Chip specific configuration and CSR Layouts. +# +# When compiled in google3 environment, gen_rule generated headers in +# platforms/... is used directly. +# When compiled in non-google3 environment, pregenerated headers in +# ${PROJECT}/... is used. + +DEFAULT_VISIBILITY = [ + "//:internal", +] + +package(default_visibility = ["//visibility:public"]) + +# All Google Owned Code except : +# - certain files in port/default/ that are under Apache 2.0 license. +licenses(["notice"]) + +# Configuration structures. 
+cc_library( + name = "config", + hdrs = [ + "apex_csr_offsets.h", + "beagle_csr_helper.h", + "breakpoint_csr_offsets.h", + "cb_bridge_csr_offsets.h", + "chip_config.h", + "chip_structures.h", + "common_csr_helper.h", + "debug_hib_user_csr_offsets.h", + "debug_scalar_core_csr_offsets.h", + "debug_tile_csr_offsets.h", + "hib_kernel_csr_offsets.h", + "hib_user_csr_offsets.h", + "interrupt_csr_offsets.h", + "memory_csr_offsets.h", + "misc_csr_offsets.h", + "msix_csr_offsets.h", + "power_throttle_csr_helper.h", + "queue_csr_offsets.h", + "register_file_csr_offsets.h", + "scalar_core_csr_offsets.h", + "scu_csr_offsets.h", + "sync_flag_csr_offsets.h", + "tile_config_csr_offsets.h", + "tile_csr_offsets.h", + "tile_thread_csr_offsets.h", + "tile_thread_trace_csr_offsets.h", + "trace_csr_offsets.h", + "usb_csr_offsets.h", + "wire_csr_offsets.h", + ], + deps = [ + "//api:chip", + "//driver:util", + "//port:integral_types", + "//port:logging", + "//port:unreachable", + ], +) + +cc_library( + name = "register_constants", + hdrs = ["register_constants.h"], + deps = [ + "//port:integral_types", + ], +) + +cc_library( + name = "scalar_core_csr_offsets_helper", + hdrs = ["scalar_core_csr_offsets_helper.h"], + deps = [ + "//base", + "//driver:util", + "//util/bits", + ], +) diff --git a/driver/config/apex_csr_offsets.h b/driver/config/apex_csr_offsets.h new file mode 100644 index 0000000..f1e8c8f --- /dev/null +++ b/driver/config/apex_csr_offsets.h @@ -0,0 +1,54 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_APEX_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_APEX_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for apex in Beagle. +// Members are intentionally named to match the GCSR register names. +struct ApexCsrOffsets { + uint64 omc0_00; + + uint64 omc0_d4; + uint64 omc0_d8; + uint64 omc0_dc; + + uint64 mst_abm_en; + uint64 slv_abm_en; + uint64 slv_err_resp_isr_mask; + uint64 mst_err_resp_isr_mask; + + uint64 mst_wr_err_resp; + uint64 mst_rd_err_resp; + uint64 slv_wr_err_resp; + uint64 slv_rd_err_resp; + + uint64 rambist_ctrl_1; + + uint64 efuse_00; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_APEX_CSR_OFFSETS_H_ diff --git a/driver/config/beagle/BUILD b/driver/config/beagle/BUILD new file mode 100644 index 0000000..73f1c37 --- /dev/null +++ b/driver/config/beagle/BUILD @@ -0,0 +1,45 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Beagle-specific configuration and CSR Layouts. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(glob(["**/*"])) + +cc_library( + name = "beagle_config", + hdrs = [ + "beagle_chip_structures.h", + "beagle_csr_offsets.h", + ], + deps = [ + "//driver/config", + "//driver/config:register_constants", + ], +) + +cc_library( + name = "beagle_chip_config", + hdrs = ["beagle_chip_config.h"], + deps = [ + ":beagle_config", + "//driver/config", + "//port:logging", + "//port:unreachable", + ], +) diff --git a/driver/config/beagle/beagle_chip_config.h b/driver/config/beagle/beagle_chip_config.h new file mode 100644 index 0000000..d760e01 --- /dev/null +++ b/driver/config/beagle/beagle_chip_config.h @@ -0,0 +1,333 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_CONFIG_H_ +#define DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_CONFIG_H_ + +#include "driver/config/beagle/beagle_chip_structures.h" +#include "driver/config/beagle/beagle_csr_offsets.h" +#include "driver/config/chip_config.h" +#include "port/logging.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// Beagle-specific configuration. +class BeagleChipConfig : public ChipConfig { + public: + ~BeagleChipConfig() override = default; + + api::Chip GetChip() const override { return api::Chip::kBeagle; } + + // Extracts CSR offsets for various modules in DarwiNN. 
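+  // The getters below hand out the statically defined kBeagle* structures
+  // from beagle_csr_offsets.h and beagle_chip_structures.h; the config itself
+  // holds no state.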
+ const HibKernelCsrOffsets& GetHibKernelCsrOffsets() const override { + return kBeagleHibKernelCsrOffsets; + } + const HibUserCsrOffsets& GetHibUserCsrOffsets() const override { + return kBeagleHibUserCsrOffsets; + } + const QueueCsrOffsets& GetInstructionQueueCsrOffsets() const override { + return kBeagleInstructionQueueCsrOffsets; + } + const HibUserCsrOffsets& GetContextSpecificHibUserCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleHibUserCsrOffsets; + } + const QueueCsrOffsets& GetContextSpecificInstructionQueueCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleInstructionQueueCsrOffsets; + } + const InterruptCsrOffsets& GetContextSpecificScalarCoreInterruptCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleScHostIntInterruptCsrOffsets; + } + const InterruptCsrOffsets& GetContextSpecificTopLevelInterruptCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleTopLevelIntInterruptCsrOffsets; + } + const InterruptCsrOffsets& GetContextSpecificFatalErrorInterruptCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleFatalErrIntInterruptCsrOffsets; + } + const WireCsrOffsets& GetContextSpecificWireCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleWireCsrOffsets; + } + const DebugHibUserCsrOffsets& GetContextSpecificDebugHibUserCsrOffsets( + int context_id) const override { + CHECK_EQ(context_id, 0); + return kBeagleDebugHibUserCsrOffsets; + } + const ScalarCoreCsrOffsets& GetScalarCoreCsrOffsets() const override { + return kBeagleScalarCoreCsrOffsets; + } + const TileConfigCsrOffsets& GetTileConfigCsrOffsets() const override { + return kBeagleTileConfigCsrOffsets; + } + const TileCsrOffsets& GetTileCsrOffsets() const override { + return kBeagleTileCsrOffsets; + } + const InterruptCsrOffsets& GetScalarCoreInterruptCsrOffsets() const override { + return kBeagleScHostIntInterruptCsrOffsets; + } + const InterruptCsrOffsets& GetTopLevelInterruptCsrOffsets() const override { + return kBeagleTopLevelIntInterruptCsrOffsets; + } + const InterruptCsrOffsets& GetFatalErrorInterruptCsrOffsets() const override { + return kBeagleFatalErrIntInterruptCsrOffsets; + } + + // Extracts CSR offsets that supports specific functionality in DarwiNN. + const MsixCsrOffsets& GetMsixCsrOffsets() const override { + return kBeagleMsixCsrOffsets; + } + const WireCsrOffsets& GetWireCsrOffsets() const override { + LOG(FATAL) << "Wire interrupt not supported."; + unreachable(); + } + const MiscCsrOffsets& GetMiscCsrOffsets() const { + return kBeagleMiscCsrOffsets; + } + + // Extracts chip-specific constants in DarwiNN. + const ChipStructures& GetChipStructures() const override { + return kBeagleChipStructures; + } + + // Extracts CSR offsets used by scalar core debugger in DarwiNN. 
+ const BreakpointCsrOffsets& GetScalarCoreBreakpointCsrOffsets() + const override { + return kBeagleScalarcoreBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetScalarCoreActivationTtuBreakpointCsrOffsets() + const override { + return kBeagleAvdatapopBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetScalarCoreInfeedTtuBreakpointCsrOffsets() + const override { + return kBeagleInfeedBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetScalarCoreOutfeedTtuBreakpointCsrOffsets() + const override { + return kBeagleOutfeedBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetScalarCoreParameterTtuBreakpointCsrOffsets() + const override { + return kBeagleParameterpopBreakpointCsrOffsets; + } + + const RegisterFileCsrOffsets& GetScalarRegisterFileCsrOffsets() + const override { + return kBeagleScalarRegisterFileCsrOffsets; + } + const RegisterFileCsrOffsets& GetPredicateRegisterFileCsrOffsets() + const override { + return kBeaglePredicateRegisterFileCsrOffsets; + } + + const MemoryCsrOffsets& GetScalarCoreMemoryCsrOffsets() const override { + return kBeagleScmemoryMemoryCsrOffsets; + } + + // Extracts CSR offsets used by tile debugger in DarwiNN. + const BreakpointCsrOffsets& GetTileOpTtuBreakpointCsrOffsets() const { + return kBeagleOpBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileWideToNarrowTtuBreakpointCsrOffsets() + const { + return kBeagleWidetonarrowBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileNarrowToWideTtuBreakpointCsrOffsets() + const { + return kBeagleNarrowtowideBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileRingBusConsumer0TtuBreakpointCsrOffsets() + const { + return kBeagleRingbusconsumer0BreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileRingBusConsumer1TtuBreakpointCsrOffsets() + const { + return kBeagleRingbusconsumer1BreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileRingBusProducerTtuBreakpointCsrOffsets() + const { + return kBeagleRingbusproducerBreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileMeshBus0TtuBreakpointCsrOffsets() const { + return kBeagleMeshbus0BreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileMeshBus1TtuBreakpointCsrOffsets() const { + return kBeagleMeshbus1BreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileMeshBus2TtuBreakpointCsrOffsets() const { + return kBeagleMeshbus2BreakpointCsrOffsets; + } + const BreakpointCsrOffsets& GetTileMeshBus3TtuBreakpointCsrOffsets() const { + return kBeagleMeshbus3BreakpointCsrOffsets; + } + + const MemoryCsrOffsets& GetTileMemoryCsrOffsets() const override { + return kBeagleMemoryMemoryCsrOffsets; + } + + // Extracts CSR offsets used by scalar core performance tracing. + const TraceCsrOffsets& GetScalarCoreActivationTtuTraceCsrOffsets() + const override { + return kBeagleAvdatapopTraceCsrOffsets; + } + const TraceCsrOffsets& GetScalarCoreInfeedTtuTraceCsrOffsets() + const override { + return kBeagleInfeedTraceCsrOffsets; + } + const TraceCsrOffsets& GetScalarCoreOutfeedTtuTraceCsrOffsets() + const override { + return kBeagleOutfeedTraceCsrOffsets; + } + const TraceCsrOffsets& GetScalarCoreParameterTtuTraceCsrOffsets() + const override { + return kBeagleParameterpopTraceCsrOffsets; + } + + // Extracts CSR offsets used by tile performance tracing. 
+ const TraceCsrOffsets& GetTileOpTtuTraceCsrOffsets() const { + return kBeagleOpTraceCsrOffsets; + } + const TraceCsrOffsets& GetTileWideToNarrowTtuTraceCsrOffsets() const { + return kBeagleDmawidetonarrowTraceCsrOffsets; + } + const TraceCsrOffsets& GetTileNarrowToWideTtuTraceCsrOffsets() const { + return kBeagleDmanarrowtowideTraceCsrOffsets; + } + const TraceCsrOffsets& GetTileRingBusConsumer0TtuTraceCsrOffsets() const { + return kBeagleDmaringbusconsumer0TraceCsrOffsets; + } + const TraceCsrOffsets& GetTileRingBusConsumer1TtuTraceCsrOffsets() const { + return kBeagleDmaringbusconsumer1TraceCsrOffsets; + } + const TraceCsrOffsets& GetTileRingBusProducerTtuTraceCsrOffsets() const { + return kBeagleDmaringbusproducerTraceCsrOffsets; + } + const TraceCsrOffsets& GetTileMeshBus0TtuTraceCsrOffsets() const { + return kBeagleDmameshbus0TraceCsrOffsets; + } + const TraceCsrOffsets& GetTileMeshBus1TtuTraceCsrOffsets() const { + return kBeagleDmameshbus1TraceCsrOffsets; + } + const TraceCsrOffsets& GetTileMeshBus2TtuTraceCsrOffsets() const { + return kBeagleDmameshbus2TraceCsrOffsets; + } + const TraceCsrOffsets& GetTileMeshBus3TtuTraceCsrOffsets() const { + return kBeagleDmameshbus3TraceCsrOffsets; + } + + // Extracts CSR offsets used to access sync flags in scalar core. + const SyncFlagCsrOffsets& GetScalarCoreAvdataPopSyncFlagCsrOffsets() + const override { + return kBeagleAvdataPopSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreParameterPopSyncFlagCsrOffsets() + const override { + return kBeagleParameterPopSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreAvdataInfeedSyncFlagCsrOffsets() + const override { + return kBeagleAvdataInfeedSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreParameterInfeedSyncFlagCsrOffsets() + const override { + return kBeagleParameterInfeedSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreScalarInfeedSyncFlagCsrOffsets() + const override { + return kBeagleScalarInfeedSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreProducerASyncFlagCsrOffsets() + const override { + return kBeagleProducerASyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreProducerBSyncFlagCsrOffsets() + const override { + return kBeagleProducerBSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreRingOutfeedSyncFlagCsrOffsets() + const override { + return kBeagleRingOutfeedSyncFlagCsrOffsets; + } + const SyncFlagCsrOffsets& GetScalarCoreScalarPipelineSyncFlagCsrOffsets() + const override { + return kBeagleScalarPipelineSyncFlagCsrOffsets; + } + + // Extracts CSR offsets used by bug report generator in DarwiNN. + const DebugHibUserCsrOffsets& GetDebugHibUserCsrOffsets() const override { + return kBeagleDebugHibUserCsrOffsets; + } + const DebugScalarCoreCsrOffsets& GetDebugScalarCoreCsrOffsets() + const override { + return kBeagleDebugScalarCoreCsrOffsets; + } + const DebugTileCsrOffsets& GetDebugTileCsrOffsets() const override { + return kBeagleDebugTileCsrOffsets; + } + + // Beagle-specific. 
+ const ApexCsrOffsets& GetApexCsrOffsets() const override { + return kBeagleApexCsrOffsets; + } + const ScuCsrOffsets& GetScuCsrOffsets() const override { + return kBeagleScuCsrOffsets; + } + const CbBridgeCsrOffsets& GetCbBridgeCsrOffsets() const override { + return kBeagleCbBridgeCsrOffsets; + } + const UsbCsrOffsets& GetUsbCsrOffsets() const override { + return kBeagleUsbCsrOffsets; + } + const InterruptCsrOffsets& GetUsbFatalErrorInterruptCsrOffsets() + const override { + return kBeagleUsbFatalErrIntInterruptCsrOffsets; + } + const InterruptCsrOffsets& GetUsbTopLevel0InterruptCsrOffsets() + const override { + return kBeagleUsbTopLevelInt0InterruptCsrOffsets; + } + const InterruptCsrOffsets& GetUsbTopLevel1InterruptCsrOffsets() + const override { + return kBeagleUsbTopLevelInt1InterruptCsrOffsets; + } + const InterruptCsrOffsets& GetUsbTopLevel2InterruptCsrOffsets() + const override { + return kBeagleUsbTopLevelInt2InterruptCsrOffsets; + } + const InterruptCsrOffsets& GetUsbTopLevel3InterruptCsrOffsets() + const override { + return kBeagleUsbTopLevelInt3InterruptCsrOffsets; + } +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_CONFIG_H_ diff --git a/driver/config/beagle/beagle_chip_structures.h b/driver/config/beagle/beagle_chip_structures.h new file mode 100644 index 0000000..51fb901 --- /dev/null +++ b/driver/config/beagle/beagle_chip_structures.h @@ -0,0 +1,58 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO GENERATED FILE. +// See http://go/darwinn-chip-structure for more info. 
+ +#ifndef DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_STRUCTURES_H_ +#define DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_STRUCTURES_H_ + +#include "driver/config/chip_structures.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +const ChipStructures kBeagleChipStructures = { + 8ULL, // NOLINT: minimum_alignment_bytes + 4096ULL, // NOLINT: allocation_alignment_bytes + 0ULL, // NOLINT: axi_dma_burst_limiter + 0ULL, // NOLINT: num_wire_interrupts + 8192ULL, // NOLINT: num_page_table_entries + 64ULL, // NOLINT: physical_address_bits + 0ULL, // NOLINT: tpu_dram_size_bytes + 196608ULL, // NOLINT: narrow_memory_capacity + 262144ULL, // NOLINT: external_narrow_memory_translate_entry_size_bytes + 4ULL, // NOLINT: number_x_tiles + 4ULL, // NOLINT: number_y_tiles + 1ULL, // NOLINT: number_compute_threads + 0ULL, // NOLINT: number_of_ring_virtual_networks + 0ULL, // NOLINT: last_z_out_cell_disable_incompatible_with_sparsity + 0ULL, // NOLINT: nlu_buffer_backpressure_causes_assertion + 0ULL, // NOLINT: mesh_rx_queue_depth + 0ULL, // NOLINT: default_vn_buffer_memory_lines + 6291456ULL, // NOLINT: csr_region_base_offset + 2097152ULL, // NOLINT: csr_region_size_bytes + 0ULL, // NOLINT: support_trace_arch_registers + 0ULL, // NOLINT: base_and_bound_unit_size_bytes + 1ULL, // NOLINT: number_of_scalar_core_contexts +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CHIP_STRUCTURES_H_ diff --git a/driver/config/beagle/beagle_csr_offsets.h b/driver/config/beagle/beagle_csr_offsets.h new file mode 100644 index 0000000..73a1bed --- /dev/null +++ b/driver/config/beagle/beagle_csr_offsets.h @@ -0,0 +1,1705 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO GENERATED FILE. +// See http://go/darwinn-chip-structure for more info. 
+ +#ifndef DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CSR_OFFSETS_H_ + +#include "driver/config/apex_csr_offsets.h" +#include "driver/config/breakpoint_csr_offsets.h" +#include "driver/config/cb_bridge_csr_offsets.h" +#include "driver/config/debug_hib_user_csr_offsets.h" +#include "driver/config/debug_scalar_core_csr_offsets.h" +#include "driver/config/debug_tile_csr_offsets.h" +#include "driver/config/hib_kernel_csr_offsets.h" +#include "driver/config/hib_user_csr_offsets.h" +#include "driver/config/interrupt_csr_offsets.h" +#include "driver/config/memory_csr_offsets.h" +#include "driver/config/misc_csr_offsets.h" +#include "driver/config/msix_csr_offsets.h" +#include "driver/config/queue_csr_offsets.h" +#include "driver/config/register_constants.h" +#include "driver/config/register_file_csr_offsets.h" +#include "driver/config/scalar_core_csr_offsets.h" +#include "driver/config/scu_csr_offsets.h" +#include "driver/config/sync_flag_csr_offsets.h" +#include "driver/config/tile_config_csr_offsets.h" +#include "driver/config/tile_csr_offsets.h" +#include "driver/config/trace_csr_offsets.h" +#include "driver/config/usb_csr_offsets.h" +#include "driver/config/wire_csr_offsets.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +const InterruptCsrOffsets kBeagleFatalErrIntInterruptCsrOffsets = { + 0x486c0, // NOLINT: fatal_err_int_control + 0x486c8, // NOLINT: fatal_err_int_status +}; + +const InterruptCsrOffsets kBeagleScHostIntInterruptCsrOffsets = { + 0x486a0, // NOLINT: sc_host_int_control + 0x486a8, // NOLINT: sc_host_int_status +}; + +const InterruptCsrOffsets kBeagleTopLevelIntInterruptCsrOffsets = { + 0x486b0, // NOLINT: top_level_int_control + 0x486b8, // NOLINT: top_level_int_status +}; + +const BreakpointCsrOffsets kBeagleAvdatapopBreakpointCsrOffsets = { + 0x44158, // NOLINT: avDataPopRunControl + 0x44168, // NOLINT: avDataPopRunStatus + 0x44160, // NOLINT: avDataPopBreakPoint +}; + +const BreakpointCsrOffsets kBeagleInfeedBreakpointCsrOffsets = { + 0x441d8, // NOLINT: infeedRunControl + 0x441e0, // NOLINT: infeedRunStatus + 0x441e8, // NOLINT: infeedBreakPoint +}; + +const BreakpointCsrOffsets kBeagleOutfeedBreakpointCsrOffsets = { + 0x44218, // NOLINT: outfeedRunControl + 0x44220, // NOLINT: outfeedRunStatus + 0x44228, // NOLINT: outfeedBreakPoint +}; + +const BreakpointCsrOffsets kBeagleParameterpopBreakpointCsrOffsets = { + 0x44198, // NOLINT: parameterPopRunControl + 0x441a8, // NOLINT: parameterPopRunStatus + 0x441a0, // NOLINT: parameterPopBreakPoint +}; + +const BreakpointCsrOffsets kBeagleScalarcoreBreakpointCsrOffsets = { + 0x44018, // NOLINT: scalarCoreRunControl + 0x44258, // NOLINT: scalarCoreRunStatus + 0x44020, // NOLINT: scalarCoreBreakPoint +}; + +const RegisterFileCsrOffsets kBeaglePredicateRegisterFileCsrOffsets = { + 0x44500, // NOLINT: predicateRegisterFile +}; + +const RegisterFileCsrOffsets kBeagleScalarRegisterFileCsrOffsets = { + 0x44400, // NOLINT: scalarRegisterFile +}; + +const SyncFlagCsrOffsets kBeagleAvdataInfeedSyncFlagCsrOffsets = { + 0x44060, // NOLINT: SyncCounter_AVDATA_INFEED +}; + +const SyncFlagCsrOffsets kBeagleAvdataPopSyncFlagCsrOffsets = { + 0x44050, // NOLINT: SyncCounter_AVDATA_POP +}; + +const SyncFlagCsrOffsets kBeagleParameterInfeedSyncFlagCsrOffsets = { + 0x44068, // NOLINT: SyncCounter_PARAMETER_INFEED +}; + +const SyncFlagCsrOffsets kBeagleParameterPopSyncFlagCsrOffsets = { + 0x44058, // NOLINT: SyncCounter_PARAMETER_POP +}; + +const 
SyncFlagCsrOffsets kBeagleProducerASyncFlagCsrOffsets = { + 0x44078, // NOLINT: SyncCounter_PRODUCER_A +}; + +const SyncFlagCsrOffsets kBeagleProducerBSyncFlagCsrOffsets = { + 0x44080, // NOLINT: SyncCounter_PRODUCER_B +}; + +const SyncFlagCsrOffsets kBeagleRingOutfeedSyncFlagCsrOffsets = { + 0x44088, // NOLINT: SyncCounter_RING_OUTFEED +}; + +const SyncFlagCsrOffsets kBeagleScalarInfeedSyncFlagCsrOffsets = { + 0x44070, // NOLINT: SyncCounter_SCALAR_INFEED +}; + +const SyncFlagCsrOffsets kBeagleScalarPipelineSyncFlagCsrOffsets = { + 0x44090, // NOLINT: SyncCounter_SCALAR_PIPELINE +}; + +const TraceCsrOffsets kBeagleAvdatapopTraceCsrOffsets = { + 0x44170, // NOLINT: avDataPopOverwriteMode + 0x44178, // NOLINT: avDataPopEnableTracing + 0x442c0, // NOLINT: avDataPopTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPopTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPopStallCauseSelect +}; + +const TraceCsrOffsets kBeagleInfeedTraceCsrOffsets = { + 0x441f0, // NOLINT: infeedOverwriteMode + 0x441f8, // NOLINT: infeedEnableTracing + 0x44340, // NOLINT: infeedTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeedTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeedStallCauseSelect +}; + +const TraceCsrOffsets kBeagleIrqcompletionbufferTraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, irqCompletionBufferOverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, irqCompletionBufferEnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, irqCompletionBufferTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, irqCompletionBufferTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // irqCompletionBufferStallCauseSelect +}; + +const TraceCsrOffsets kBeagleOutfeedTraceCsrOffsets = { + 0x44230, // NOLINT: outfeedOverwriteMode + 0x44238, // NOLINT: outfeedEnableTracing + 0x44380, // NOLINT: outfeedTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeedTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeedStallCauseSelect +}; + +const TraceCsrOffsets kBeagleParameterpopTraceCsrOffsets = { + 0x441b0, // NOLINT: parameterPopOverwriteMode + 0x441b8, // NOLINT: parameterPopEnableTracing + 0x44300, // NOLINT: parameterPopTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPopTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPopStallCauseSelect +}; + +const BreakpointCsrOffsets kBeagleMeshbus0BreakpointCsrOffsets = { + 0x42250, // NOLINT: meshBus0RunControl + 0x42258, // NOLINT: meshBus0RunStatus + 0x42260, // NOLINT: meshBus0BreakPoint +}; + +const BreakpointCsrOffsets kBeagleMeshbus1BreakpointCsrOffsets = { + 0x42298, // NOLINT: meshBus1RunControl + 0x422a0, // NOLINT: meshBus1RunStatus + 0x422a8, // NOLINT: meshBus1BreakPoint +}; + +const BreakpointCsrOffsets kBeagleMeshbus2BreakpointCsrOffsets = { + 0x422e0, // NOLINT: meshBus2RunControl + 0x422e8, // NOLINT: meshBus2RunStatus + 0x422f0, // NOLINT: meshBus2BreakPoint +}; + +const BreakpointCsrOffsets kBeagleMeshbus3BreakpointCsrOffsets = { + 0x42328, // NOLINT: meshBus3RunControl + 0x42330, // NOLINT: meshBus3RunStatus + 0x42338, // NOLINT: meshBus3BreakPoint +}; + +const BreakpointCsrOffsets kBeagleNarrowtowideBreakpointCsrOffsets = { + 0x42150, // NOLINT: narrowToWideRunControl + 0x42158, // NOLINT: narrowToWideRunStatus + 0x42160, // NOLINT: narrowToWideBreakPoint +}; + +const BreakpointCsrOffsets kBeagleOpBreakpointCsrOffsets = { + 0x420c0, // NOLINT: opRunControl + 0x420e0, // NOLINT: opRunStatus + 0x420d0, // NOLINT: opBreakPoint +}; + +const 
BreakpointCsrOffsets kBeagleRingbusconsumer0BreakpointCsrOffsets = { + 0x42190, // NOLINT: ringBusConsumer0RunControl + 0x42198, // NOLINT: ringBusConsumer0RunStatus + 0x421a0, // NOLINT: ringBusConsumer0BreakPoint +}; + +const BreakpointCsrOffsets kBeagleRingbusconsumer1BreakpointCsrOffsets = { + 0x421d0, // NOLINT: ringBusConsumer1RunControl + 0x421d8, // NOLINT: ringBusConsumer1RunStatus + 0x421e0, // NOLINT: ringBusConsumer1BreakPoint +}; + +const BreakpointCsrOffsets kBeagleRingbusproducerBreakpointCsrOffsets = { + 0x42210, // NOLINT: ringBusProducerRunControl + 0x42218, // NOLINT: ringBusProducerRunStatus + 0x42220, // NOLINT: ringBusProducerBreakPoint +}; + +const BreakpointCsrOffsets kBeagleWidetonarrowBreakpointCsrOffsets = { + 0x42110, // NOLINT: wideToNarrowRunControl + 0x42118, // NOLINT: wideToNarrowRunStatus + 0x42120, // NOLINT: wideToNarrowBreakPoint +}; + +const SyncFlagCsrOffsets kBeagleAvdataSyncFlagCsrOffsets = { + 0x42028, // NOLINT: SyncCounter_AVDATA +}; + +const SyncFlagCsrOffsets kBeagleMeshEastInSyncFlagCsrOffsets = { + 0x42048, // NOLINT: SyncCounter_MESH_EAST_IN +}; + +const SyncFlagCsrOffsets kBeagleMeshEastOutSyncFlagCsrOffsets = { + 0x42068, // NOLINT: SyncCounter_MESH_EAST_OUT +}; + +const SyncFlagCsrOffsets kBeagleMeshNorthInSyncFlagCsrOffsets = { + 0x42040, // NOLINT: SyncCounter_MESH_NORTH_IN +}; + +const SyncFlagCsrOffsets kBeagleMeshNorthOutSyncFlagCsrOffsets = { + 0x42060, // NOLINT: SyncCounter_MESH_NORTH_OUT +}; + +const SyncFlagCsrOffsets kBeagleMeshSouthInSyncFlagCsrOffsets = { + 0x42050, // NOLINT: SyncCounter_MESH_SOUTH_IN +}; + +const SyncFlagCsrOffsets kBeagleMeshSouthOutSyncFlagCsrOffsets = { + 0x42070, // NOLINT: SyncCounter_MESH_SOUTH_OUT +}; + +const SyncFlagCsrOffsets kBeagleMeshWestInSyncFlagCsrOffsets = { + 0x42058, // NOLINT: SyncCounter_MESH_WEST_IN +}; + +const SyncFlagCsrOffsets kBeagleMeshWestOutSyncFlagCsrOffsets = { + 0x42078, // NOLINT: SyncCounter_MESH_WEST_OUT +}; + +const SyncFlagCsrOffsets kBeagleNarrowToWideSyncFlagCsrOffsets = { + 0x42090, // NOLINT: SyncCounter_NARROW_TO_WIDE +}; + +const SyncFlagCsrOffsets kBeagleParametersSyncFlagCsrOffsets = { + 0x42030, // NOLINT: SyncCounter_PARAMETERS +}; + +const SyncFlagCsrOffsets kBeaglePartialSumsSyncFlagCsrOffsets = { + 0x42038, // NOLINT: SyncCounter_PARTIAL_SUMS +}; + +const SyncFlagCsrOffsets kBeagleRingProducerASyncFlagCsrOffsets = { + 0x420b0, // NOLINT: SyncCounter_RING_PRODUCER_A +}; + +const SyncFlagCsrOffsets kBeagleRingProducerBSyncFlagCsrOffsets = { + 0x420b8, // NOLINT: SyncCounter_RING_PRODUCER_B +}; + +const SyncFlagCsrOffsets kBeagleRingReadASyncFlagCsrOffsets = { + 0x42098, // NOLINT: SyncCounter_RING_READ_A +}; + +const SyncFlagCsrOffsets kBeagleRingReadBSyncFlagCsrOffsets = { + 0x420a0, // NOLINT: SyncCounter_RING_READ_B +}; + +const SyncFlagCsrOffsets kBeagleRingWriteSyncFlagCsrOffsets = { + 0x420a8, // NOLINT: SyncCounter_RING_WRITE +}; + +const SyncFlagCsrOffsets kBeagleWideToNarrowSyncFlagCsrOffsets = { + 0x42080, // NOLINT: SyncCounter_WIDE_TO_NARROW +}; + +const SyncFlagCsrOffsets kBeagleWideToScalingSyncFlagCsrOffsets = { + 0x42088, // NOLINT: SyncCounter_WIDE_TO_SCALING +}; + +const TraceCsrOffsets kBeagleDmameshbus0TraceCsrOffsets = { + 0x42270, // NOLINT: dmaMeshBus0OverwriteMode + 0x42278, // NOLINT: dmaMeshBus0EnableTracing + 0x42740, // NOLINT: dmaMeshBus0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus0TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus0StallCauseSelect +}; + +const TraceCsrOffsets 
kBeagleDmameshbus1TraceCsrOffsets = { + 0x422b8, // NOLINT: dmaMeshBus1OverwriteMode + 0x422c0, // NOLINT: dmaMeshBus1EnableTracing + 0x427c0, // NOLINT: dmaMeshBus1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus1TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus1StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmameshbus2TraceCsrOffsets = { + 0x42300, // NOLINT: dmaMeshBus2OverwriteMode + 0x42308, // NOLINT: dmaMeshBus2EnableTracing + 0x42840, // NOLINT: dmaMeshBus2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus2TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus2StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmameshbus3TraceCsrOffsets = { + 0x42348, // NOLINT: dmaMeshBus3OverwriteMode + 0x42350, // NOLINT: dmaMeshBus3EnableTracing + 0x428c0, // NOLINT: dmaMeshBus3Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus3TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus3StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtonarrowTraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToNarrowOverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToNarrowEnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToNarrowTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToNarrowTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToNarrowStallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide0TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_0TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_0StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide1TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_1TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_1StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide2TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_2OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_2EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_2TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_2StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide3TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_3OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_3EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_3Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_3TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_3StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide4TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_4OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_4EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_4Trace + 
kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_4TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_4StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide5TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_5OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_5EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_5Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_5TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_5StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide6TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_6OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_6EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_6Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_6TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_6StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowide7TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_7OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_7EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_7Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWide_7TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaNarrowToWide_7StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmanarrowtowideTraceCsrOffsets = { + 0x42168, // NOLINT: dmaNarrowToWideOverwriteMode + 0x42170, // NOLINT: dmaNarrowToWideEnableTracing + 0x42600, // NOLINT: dmaNarrowToWideTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWideTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWideStallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmaringbusconsumer0TraceCsrOffsets = { + 0x421a8, // NOLINT: dmaRingBusConsumer0OverwriteMode + 0x421b0, // NOLINT: dmaRingBusConsumer0EnableTracing + 0x42640, // NOLINT: dmaRingBusConsumer0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaRingBusConsumer0TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusConsumer0StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmaringbusconsumer1TraceCsrOffsets = { + 0x421e8, // NOLINT: dmaRingBusConsumer1OverwriteMode + 0x421f0, // NOLINT: dmaRingBusConsumer1EnableTracing + 0x42680, // NOLINT: dmaRingBusConsumer1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaRingBusConsumer1TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusConsumer1StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmaringbusproducerTraceCsrOffsets = { + 0x42228, // NOLINT: dmaRingBusProducerOverwriteMode + 0x42230, // NOLINT: dmaRingBusProducerEnableTracing + 0x426c0, // NOLINT: dmaRingBusProducerTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaRingBusProducerTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusProducerStallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow0TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_0TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_0StallCauseSelect +}; + +const TraceCsrOffsets 
kBeagleDmawidetonarrow1TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_1TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_1StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow2TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_2OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_2EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_2TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_2StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow3TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_3OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_3EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_3Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_3TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_3StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow4TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_4OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_4EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_4Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_4TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_4StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow5TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_5OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_5EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_5Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_5TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_5StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow6TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_6OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_6EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_6Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_6TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_6StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrow7TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_7OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_7EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_7Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrow_7TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaWideToNarrow_7StallCauseSelect +}; + +const TraceCsrOffsets kBeagleDmawidetonarrowTraceCsrOffsets = { + 0x42128, // NOLINT: dmaWideToNarrowOverwriteMode + 0x42130, // NOLINT: dmaWideToNarrowEnableTracing + 0x42500, // NOLINT: dmaWideToNarrowTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrowTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrowStallCauseSelect +}; + +const TraceCsrOffsets 
kBeagleOp0TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_0TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_0StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp1TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_1TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_1StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp2TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_2OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_2EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_2TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_2StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp3TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_3OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_3EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_3Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_3TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_3StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp4TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_4OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_4EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_4Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_4TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_4StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp5TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_5OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_5EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_5Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_5TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_5StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp6TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_6OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_6EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_6Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_6TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_6StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOp7TraceCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_7OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_7EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_7Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_7TimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, Op_7StallCauseSelect +}; + +const TraceCsrOffsets kBeagleOpTraceCsrOffsets = { + 0x420e8, // NOLINT: OpOverwriteMode + 0x420f0, // NOLINT: OpEnableTracing + 0x42400, // NOLINT: OpTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, OpTimeStampUnit + kCsrRegisterSpaceInvalidOffset, // UNUSED, OpStallCauseSelect +}; + +const DebugHibUserCsrOffsets kBeagleDebugHibUserCsrOffsets = { + 0x48010, // NOLINT: instruction_inbound_queue_total_occupancy + 0x48018, // NOLINT: instruction_inbound_queue_threshold_counter + 0x48020, // NOLINT: instruction_inbound_queue_insertion_counter + 0x48028, // NOLINT: instruction_inbound_queue_full_counter + 0x48030, // 
NOLINT: input_actv_inbound_queue_total_occupancy + 0x48038, // NOLINT: input_actv_inbound_queue_threshold_counter + 0x48040, // NOLINT: input_actv_inbound_queue_insertion_counter + 0x48048, // NOLINT: input_actv_inbound_queue_full_counter + 0x48050, // NOLINT: param_inbound_queue_total_occupancy + 0x48058, // NOLINT: param_inbound_queue_threshold_counter + 0x48060, // NOLINT: param_inbound_queue_insertion_counter + 0x48068, // NOLINT: param_inbound_queue_full_counter + 0x48070, // NOLINT: output_actv_inbound_queue_total_occupancy + 0x48078, // NOLINT: output_actv_inbound_queue_threshold_counter + 0x48080, // NOLINT: output_actv_inbound_queue_insertion_counter + 0x48088, // NOLINT: output_actv_inbound_queue_full_counter + 0x48090, // NOLINT: status_block_write_inbound_queue_total_occupancy + 0x48098, // NOLINT: status_block_write_inbound_queue_threshold_counter + 0x480a0, // NOLINT: status_block_write_inbound_queue_insertion_counter + 0x480a8, // NOLINT: status_block_write_inbound_queue_full_counter + 0x480b0, // NOLINT: queue_fetch_inbound_queue_total_occupancy + 0x480b8, // NOLINT: queue_fetch_inbound_queue_threshold_counter + 0x480c0, // NOLINT: queue_fetch_inbound_queue_insertion_counter + 0x480c8, // NOLINT: queue_fetch_inbound_queue_full_counter + 0x480d0, // NOLINT: instruction_outbound_queue_total_occupancy + 0x480d8, // NOLINT: instruction_outbound_queue_threshold_counter + 0x480e0, // NOLINT: instruction_outbound_queue_insertion_counter + 0x480e8, // NOLINT: instruction_outbound_queue_full_counter + 0x480f0, // NOLINT: input_actv_outbound_queue_total_occupancy + 0x480f8, // NOLINT: input_actv_outbound_queue_threshold_counter + 0x48100, // NOLINT: input_actv_outbound_queue_insertion_counter + 0x48108, // NOLINT: input_actv_outbound_queue_full_counter + 0x48110, // NOLINT: param_outbound_queue_total_occupancy + 0x48118, // NOLINT: param_outbound_queue_threshold_counter + 0x48120, // NOLINT: param_outbound_queue_insertion_counter + 0x48128, // NOLINT: param_outbound_queue_full_counter + 0x48130, // NOLINT: output_actv_outbound_queue_total_occupancy + 0x48138, // NOLINT: output_actv_outbound_queue_threshold_counter + 0x48140, // NOLINT: output_actv_outbound_queue_insertion_counter + 0x48148, // NOLINT: output_actv_outbound_queue_full_counter + 0x48150, // NOLINT: status_block_write_outbound_queue_total_occupancy + 0x48158, // NOLINT: status_block_write_outbound_queue_threshold_counter + 0x48160, // NOLINT: status_block_write_outbound_queue_insertion_counter + 0x48168, // NOLINT: status_block_write_outbound_queue_full_counter + 0x48170, // NOLINT: queue_fetch_outbound_queue_total_occupancy + 0x48178, // NOLINT: queue_fetch_outbound_queue_threshold_counter + 0x48180, // NOLINT: queue_fetch_outbound_queue_insertion_counter + 0x48188, // NOLINT: queue_fetch_outbound_queue_full_counter + 0x48190, // NOLINT: page_table_request_outbound_queue_total_occupancy + 0x48198, // NOLINT: page_table_request_outbound_queue_threshold_counter + 0x481a0, // NOLINT: page_table_request_outbound_queue_insertion_counter + 0x481a8, // NOLINT: page_table_request_outbound_queue_full_counter + 0x481b0, // NOLINT: read_tracking_fifo_total_occupancy + 0x481b8, // NOLINT: read_tracking_fifo_threshold_counter + 0x481c0, // NOLINT: read_tracking_fifo_insertion_counter + 0x481c8, // NOLINT: read_tracking_fifo_full_counter + 0x481d0, // NOLINT: write_tracking_fifo_total_occupancy + 0x481d8, // NOLINT: write_tracking_fifo_threshold_counter + 0x481e0, // NOLINT: write_tracking_fifo_insertion_counter + 0x481e8, // NOLINT: 
write_tracking_fifo_full_counter + 0x481f0, // NOLINT: read_buffer_total_occupancy + 0x481f8, // NOLINT: read_buffer_threshold_counter + 0x48200, // NOLINT: read_buffer_insertion_counter + 0x48208, // NOLINT: read_buffer_full_counter + 0x48210, // NOLINT: axi_aw_credit_shim_total_occupancy + 0x48218, // NOLINT: axi_aw_credit_shim_threshold_counter + 0x48220, // NOLINT: axi_aw_credit_shim_insertion_counter + 0x48228, // NOLINT: axi_aw_credit_shim_full_counter + 0x48230, // NOLINT: axi_ar_credit_shim_total_occupancy + 0x48238, // NOLINT: axi_ar_credit_shim_threshold_counter + 0x48240, // NOLINT: axi_ar_credit_shim_insertion_counter + 0x48248, // NOLINT: axi_ar_credit_shim_full_counter + 0x48250, // NOLINT: axi_w_credit_shim_total_occupancy + 0x48258, // NOLINT: axi_w_credit_shim_threshold_counter + 0x48260, // NOLINT: axi_w_credit_shim_insertion_counter + 0x48268, // NOLINT: axi_w_credit_shim_full_counter + 0x48270, // NOLINT: instruction_inbound_queue_empty_cycles_count + 0x48278, // NOLINT: input_actv_inbound_queue_empty_cycles_count + 0x48280, // NOLINT: param_inbound_queue_empty_cycles_count + 0x48288, // NOLINT: output_actv_inbound_queue_empty_cycles_count + 0x48290, // NOLINT: status_block_write_inbound_queue_empty_cycles_count + 0x48298, // NOLINT: queue_fetch_inbound_queue_empty_cycles_count + 0x482a0, // NOLINT: instruction_outbound_queue_empty_cycles_count + 0x482a8, // NOLINT: input_actv_outbound_queue_empty_cycles_count + 0x482b0, // NOLINT: param_outbound_queue_empty_cycles_count + 0x482b8, // NOLINT: output_actv_outbound_queue_empty_cycles_count + 0x482c0, // NOLINT: status_block_write_outbound_queue_empty_cycles_count + 0x482c8, // NOLINT: queue_fetch_outbound_queue_empty_cycles_count + 0x482d0, // NOLINT: page_table_request_outbound_queue_empty_cycles_count + 0x482d8, // NOLINT: read_tracking_fifo_empty_cycles_count + 0x482e0, // NOLINT: write_tracking_fifo_empty_cycles_count + 0x482e8, // NOLINT: read_buffer_empty_cycles_count + 0x482f0, // NOLINT: read_request_arbiter_instruction_request_cycles + 0x482f8, // NOLINT: read_request_arbiter_instruction_blocked_cycles + 0x48300, // NOLINT: + // read_request_arbiter_instruction_blocked_by_arbitration_cycles + 0x48308, // NOLINT: + // read_request_arbiter_instruction_cycles_blocked_over_threshold + 0x48310, // NOLINT: read_request_arbiter_input_actv_request_cycles + 0x48318, // NOLINT: read_request_arbiter_input_actv_blocked_cycles + 0x48320, // NOLINT: + // read_request_arbiter_input_actv_blocked_by_arbitration_cycles + 0x48328, // NOLINT: + // read_request_arbiter_input_actv_cycles_blocked_over_threshold + 0x48330, // NOLINT: read_request_arbiter_param_request_cycles + 0x48338, // NOLINT: read_request_arbiter_param_blocked_cycles + 0x48340, // NOLINT: + // read_request_arbiter_param_blocked_by_arbitration_cycles + 0x48348, // NOLINT: + // read_request_arbiter_param_cycles_blocked_over_threshold + 0x48350, // NOLINT: read_request_arbiter_queue_fetch_request_cycles + 0x48358, // NOLINT: read_request_arbiter_queue_fetch_blocked_cycles + 0x48360, // NOLINT: + // read_request_arbiter_queue_fetch_blocked_by_arbitration_cycles + 0x48368, // NOLINT: + // read_request_arbiter_queue_fetch_cycles_blocked_over_threshold + 0x48370, // NOLINT: read_request_arbiter_page_table_request_request_cycles + 0x48378, // NOLINT: read_request_arbiter_page_table_request_blocked_cycles + 0x48380, // NOLINT: + // read_request_arbiter_page_table_request_blocked_by_arbitration_cycles + 0x48388, // NOLINT: + // 
read_request_arbiter_page_table_request_cycles_blocked_over_threshold + 0x48390, // NOLINT: write_request_arbiter_output_actv_request_cycles + 0x48398, // NOLINT: write_request_arbiter_output_actv_blocked_cycles + 0x483a0, // NOLINT: + // write_request_arbiter_output_actv_blocked_by_arbitration_cycles + 0x483a8, // NOLINT: + // write_request_arbiter_output_actv_cycles_blocked_over_threshold + 0x483b0, // NOLINT: write_request_arbiter_status_block_write_request_cycles + 0x483b8, // NOLINT: write_request_arbiter_status_block_write_blocked_cycles + 0x483c0, // NOLINT: + // write_request_arbiter_status_block_write_blocked_by_arbitration_cycles + 0x483c8, // NOLINT: + // write_request_arbiter_status_block_write_cycles_blocked_over_threshold + 0x483d0, // NOLINT: address_translation_arbiter_instruction_request_cycles + 0x483d8, // NOLINT: address_translation_arbiter_instruction_blocked_cycles + 0x483e0, // NOLINT: + // address_translation_arbiter_instruction_blocked_by_arbitration_cycles + 0x483e8, // NOLINT: + // address_translation_arbiter_instruction_cycles_blocked_over_threshold + 0x483f0, // NOLINT: address_translation_arbiter_input_actv_request_cycles + 0x483f8, // NOLINT: address_translation_arbiter_input_actv_blocked_cycles + 0x48400, // NOLINT: + // address_translation_arbiter_input_actv_blocked_by_arbitration_cycles + 0x48408, // NOLINT: + // address_translation_arbiter_input_actv_cycles_blocked_over_threshold + 0x48410, // NOLINT: address_translation_arbiter_param_request_cycles + 0x48418, // NOLINT: address_translation_arbiter_param_blocked_cycles + 0x48420, // NOLINT: + // address_translation_arbiter_param_blocked_by_arbitration_cycles + 0x48428, // NOLINT: + // address_translation_arbiter_param_cycles_blocked_over_threshold + 0x48430, // NOLINT: + // address_translation_arbiter_status_block_write_request_cycles + 0x48438, // NOLINT: + // address_translation_arbiter_status_block_write_blocked_cycles + 0x48440, // NOLINT: + // address_translation_arbiter_status_block_write_blocked_by_arbitration_cycles + 0x48448, // NOLINT: + // address_translation_arbiter_status_block_write_cycles_blocked_over_threshold + 0x48450, // NOLINT: address_translation_arbiter_output_actv_request_cycles + 0x48458, // NOLINT: address_translation_arbiter_output_actv_blocked_cycles + 0x48460, // NOLINT: + // address_translation_arbiter_output_actv_blocked_by_arbitration_cycles + 0x48468, // NOLINT: + // address_translation_arbiter_output_actv_cycles_blocked_over_threshold + 0x48470, // NOLINT: address_translation_arbiter_queue_fetch_request_cycles + 0x48478, // NOLINT: address_translation_arbiter_queue_fetch_blocked_cycles + 0x48480, // NOLINT: + // address_translation_arbiter_queue_fetch_blocked_by_arbitration_cycles + 0x48488, // NOLINT: + // address_translation_arbiter_queue_fetch_cycles_blocked_over_threshold + 0x48490, // NOLINT: issued_interrupt_count + 0x48498, // NOLINT: data_read_16byte_count + 0x484a0, // NOLINT: waiting_for_tag_cycles + 0x484a8, // NOLINT: waiting_for_axi_cycles + 0x484b0, // NOLINT: simple_translations + 0x484c8, // NOLINT: instruction_credits_per_cycle_sum + 0x484d0, // NOLINT: input_actv_credits_per_cycle_sum + 0x484d8, // NOLINT: param_credits_per_cycle_sum + 0x484e0, // NOLINT: output_actv_credits_per_cycle_sum + 0x484e8, // NOLINT: status_block_write_credits_per_cycle_sum + 0x484f0, // NOLINT: queue_fetch_credits_per_cycle_sum + 0x484f8, // NOLINT: page_table_request_credits_per_cycle_sum + 0x48500, // NOLINT: output_actv_queue_control + 0x48508, // NOLINT: 
output_actv_queue_status + 0x48510, // NOLINT: output_actv_queue_descriptor_size + 0x48518, // NOLINT: output_actv_queue_minimum_size + 0x48520, // NOLINT: output_actv_queue_maximum_size + 0x48528, // NOLINT: output_actv_queue_base + 0x48530, // NOLINT: output_actv_queue_status_block_base + 0x48538, // NOLINT: output_actv_queue_size + 0x48540, // NOLINT: output_actv_queue_tail + 0x48548, // NOLINT: output_actv_queue_fetched_head + 0x48550, // NOLINT: output_actv_queue_completed_head + 0x48558, // NOLINT: output_actv_queue_int_control + 0x48560, // NOLINT: output_actv_queue_int_status + 0x48568, // NOLINT: instruction_queue_control + 0x48570, // NOLINT: instruction_queue_status + 0x48578, // NOLINT: instruction_queue_descriptor_size + 0x48580, // NOLINT: instruction_queue_minimum_size + 0x48588, // NOLINT: instruction_queue_maximum_size + 0x48590, // NOLINT: instruction_queue_base + 0x48598, // NOLINT: instruction_queue_status_block_base + 0x485a0, // NOLINT: instruction_queue_size + 0x485a8, // NOLINT: instruction_queue_tail + 0x485b0, // NOLINT: instruction_queue_fetched_head + 0x485b8, // NOLINT: instruction_queue_completed_head + 0x485c0, // NOLINT: instruction_queue_int_control + 0x485c8, // NOLINT: instruction_queue_int_status + 0x485d0, // NOLINT: input_actv_queue_control + 0x485d8, // NOLINT: input_actv_queue_status + 0x485e0, // NOLINT: input_actv_queue_descriptor_size + 0x485e8, // NOLINT: input_actv_queue_minimum_size + 0x485f0, // NOLINT: input_actv_queue_maximum_size + 0x485f8, // NOLINT: input_actv_queue_base + 0x48600, // NOLINT: input_actv_queue_status_block_base + 0x48608, // NOLINT: input_actv_queue_size + 0x48610, // NOLINT: input_actv_queue_tail + 0x48618, // NOLINT: input_actv_queue_fetched_head + 0x48620, // NOLINT: input_actv_queue_completed_head + 0x48628, // NOLINT: input_actv_queue_int_control + 0x48630, // NOLINT: input_actv_queue_int_status + 0x48638, // NOLINT: param_queue_control + 0x48640, // NOLINT: param_queue_status + 0x48648, // NOLINT: param_queue_descriptor_size + 0x48650, // NOLINT: param_queue_minimum_size + 0x48658, // NOLINT: param_queue_maximum_size + 0x48660, // NOLINT: param_queue_base + 0x48668, // NOLINT: param_queue_status_block_base + 0x48670, // NOLINT: param_queue_size + 0x48678, // NOLINT: param_queue_tail + 0x48680, // NOLINT: param_queue_fetched_head + 0x48688, // NOLINT: param_queue_completed_head + 0x48690, // NOLINT: param_queue_int_control + 0x48698, // NOLINT: param_queue_int_status + 0x486a0, // NOLINT: sc_host_int_control + 0x486a8, // NOLINT: sc_host_int_status + 0x486b0, // NOLINT: top_level_int_control + 0x486b8, // NOLINT: top_level_int_status + 0x486c0, // NOLINT: fatal_err_int_control + 0x486c8, // NOLINT: fatal_err_int_status + 0x486d0, // NOLINT: sc_host_int_count + 0x486d8, // NOLINT: dma_pause + 0x486e0, // NOLINT: dma_paused + 0x486e8, // NOLINT: status_block_update + 0x486f0, // NOLINT: hib_error_status + 0x486f8, // NOLINT: hib_error_mask + 0x48700, // NOLINT: hib_first_error_status + 0x48708, // NOLINT: hib_first_error_timestamp + 0x48710, // NOLINT: hib_inject_error + 0x48718, // NOLINT: read_request_arbiter + 0x48720, // NOLINT: write_request_arbiter + 0x48728, // NOLINT: address_translation_arbiter + 0x48730, // NOLINT: sender_queue_threshold + 0x48738, // NOLINT: page_fault_address + 0x48740, // NOLINT: instruction_credits + 0x48748, // NOLINT: input_actv_credits + 0x48750, // NOLINT: param_credits + 0x48758, // NOLINT: output_actv_credits + 0x48760, // NOLINT: pause_state + 0x48768, // NOLINT: snapshot + 0x48770, 
// NOLINT: idle_assert + 0x48778, // NOLINT: wire_int_pending_bit_array + 0x48788, // NOLINT: tileconfig0 + 0x48790, // NOLINT: tileconfig1 +}; + +const DebugScalarCoreCsrOffsets kBeagleDebugScalarCoreCsrOffsets = { + 0x44000, // NOLINT: topology + 0x44008, // NOLINT: scMemoryCapacity + 0x44010, // NOLINT: tileMemoryCapacity + 0x44040, // NOLINT: scMemoryAccess + 0x44048, // NOLINT: scMemoryData + 0x44288, // NOLINT: Timeout + 0x44260, // NOLINT: Error_ScalarCore + 0x44268, // NOLINT: Error_Mask_ScalarCore + 0x44270, // NOLINT: Error_Force_ScalarCore + 0x44278, // NOLINT: Error_Timestamp_ScalarCore + 0x44280, // NOLINT: Error_Info_ScalarCore + 0x44018, // NOLINT: scalarCoreRunControl + 0x44020, // NOLINT: scalarCoreBreakPoint + 0x44028, // NOLINT: currentPc + 0x44038, // NOLINT: executeControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, currentPc_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, currentPc_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_2BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, currentPc_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_3RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_3BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, currentPc_3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_3 + 0x44050, // NOLINT: SyncCounter_AVDATA_POP + 0x44058, // NOLINT: SyncCounter_PARAMETER_POP + 0x44060, // NOLINT: SyncCounter_AVDATA_INFEED + 0x44068, // NOLINT: SyncCounter_PARAMETER_INFEED + 0x44070, // NOLINT: SyncCounter_SCALAR_INFEED + 0x44078, // NOLINT: SyncCounter_PRODUCER_A + 0x44080, // NOLINT: SyncCounter_PRODUCER_B + 0x44088, // NOLINT: SyncCounter_RING_OUTFEED + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_POP_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_POP_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_INFEED_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_INFEED_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_SCALAR_INFEED_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_A_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_B_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_RING_OUTFEED_0_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_POP_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_POP_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_INFEED_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_INFEED_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_SCALAR_INFEED_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_A_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_B_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_RING_OUTFEED_1_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_POP_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, 
SyncCounter_PARAMETER_POP_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_INFEED_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_INFEED_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_SCALAR_INFEED_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_A_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_B_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_RING_OUTFEED_2_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_POP_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_POP_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_AVDATA_INFEED_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PARAMETER_INFEED_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_SCALAR_INFEED_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_A_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_PRODUCER_B_3_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, SyncCounter_RING_OUTFEED_3_0 + 0x44158, // NOLINT: avDataPopRunControl + 0x44160, // NOLINT: avDataPopBreakPoint + 0x44168, // NOLINT: avDataPopRunStatus + 0x44170, // NOLINT: avDataPopOverwriteMode + 0x44178, // NOLINT: avDataPopEnableTracing + 0x44180, // NOLINT: avDataPopStartCycle + 0x44188, // NOLINT: avDataPopEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPopStallCycleCount + 0x44190, // NOLINT: avDataPopProgramCounter + 0x442a0, // NOLINT: avDataPopTtuStateRegFile + 0x442c0, // NOLINT: avDataPopTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2EndCycle + 
kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3Trace + 0x44198, // NOLINT: parameterPopRunControl + 0x441a0, // NOLINT: parameterPopBreakPoint + 0x441a8, // NOLINT: parameterPopRunStatus + 0x441b0, // NOLINT: parameterPopOverwriteMode + 0x441b8, // NOLINT: parameterPopEnableTracing + 0x441c0, // NOLINT: parameterPopStartCycle + 0x441c8, // NOLINT: parameterPopEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPopStallCycleCount + 0x441d0, // NOLINT: parameterPopProgramCounter + 0x442e0, // NOLINT: parameterPopTtuStateRegFile + 0x44300, // NOLINT: parameterPopTrace + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2StartCycle + 
kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3Trace + 0x441d8, // NOLINT: infeedRunControl + 0x441e0, // NOLINT: infeedRunStatus + 0x441e8, // NOLINT: infeedBreakPoint + 0x441f0, // NOLINT: infeedOverwriteMode + 0x441f8, // NOLINT: infeedEnableTracing + 0x44200, // NOLINT: infeedStartCycle + 0x44208, // NOLINT: infeedEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeedStallCycleCount + 0x44210, // NOLINT: infeedProgramCounter + 0x44320, // NOLINT: infeedTtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0ProgramCounter + 
kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1TtuStateRegFile + 0x44218, // NOLINT: outfeedRunControl + 0x44220, // NOLINT: outfeedRunStatus + 0x44228, // NOLINT: 
outfeedBreakPoint + 0x44230, // NOLINT: outfeedOverwriteMode + 0x44238, // NOLINT: outfeedEnableTracing + 0x44240, // NOLINT: outfeedStartCycle + 0x44248, // NOLINT: outfeedEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeedStallCycleCount + 0x44250, // NOLINT: outfeedProgramCounter + 0x44360, // NOLINT: outfeedTtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, 
outfeed_2_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_0TtuStateRegFile + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1RunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1BreakPoint + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1OverwriteMode + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1EnableTracing + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1StartCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1StallCycleCount + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1ProgramCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1TtuStateRegFile + 0x44258, // NOLINT: scalarCoreRunStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarCoreRunStatus_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarCoreRunStatus_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarCoreRunStatus_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarCoreRunStatus_3 +}; + +const DebugTileCsrOffsets kBeagleDebugTileCsrOffsets = { + kCsrRegisterSpaceInvalidOffset, // UNUSED, TileClockControl + 0x42000, // NOLINT: tileid + 0x42008, // NOLINT: scratchpad + 0x42010, // NOLINT: memoryAccess + 0x42018, // NOLINT: memoryData + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_3 + 0x42020, // NOLINT: deepSleep + 0x42028, // NOLINT: SyncCounter_AVDATA + 0x42030, // NOLINT: SyncCounter_PARAMETERS + 0x42038, // NOLINT: SyncCounter_PARTIAL_SUMS + 0x42040, // NOLINT: SyncCounter_MESH_NORTH_IN + 0x42048, // NOLINT: SyncCounter_MESH_EAST_IN + 0x42050, // NOLINT: SyncCounter_MESH_SOUTH_IN + 0x42058, // NOLINT: SyncCounter_MESH_WEST_IN + 0x42060, // NOLINT: SyncCounter_MESH_NORTH_OUT + 0x42068, // 
NOLINT: SyncCounter_MESH_EAST_OUT + 0x42070, // NOLINT: SyncCounter_MESH_SOUTH_OUT + 0x42078, // NOLINT: SyncCounter_MESH_WEST_OUT + 0x42080, // NOLINT: SyncCounter_WIDE_TO_NARROW + 0x42088, // NOLINT: SyncCounter_WIDE_TO_SCALING + 0x42090, // NOLINT: SyncCounter_NARROW_TO_WIDE + 0x42098, // NOLINT: SyncCounter_RING_READ_A + 0x420a0, // NOLINT: SyncCounter_RING_READ_B + 0x420a8, // NOLINT: SyncCounter_RING_WRITE + 0x420b0, // NOLINT: SyncCounter_RING_PRODUCER_A + 0x420b8, // NOLINT: SyncCounter_RING_PRODUCER_B + 0x420c0, // NOLINT: opRunControl + 0x420c8, // NOLINT: PowerSaveData + 0x420d0, // NOLINT: opBreakPoint + 0x420d8, // NOLINT: StallCounter + 0x420e0, // NOLINT: opRunStatus + 0x420e8, // NOLINT: OpOverwriteMode + 0x420f0, // NOLINT: OpEnableTracing + 0x420f8, // NOLINT: OpStartCycle + 0x42100, // NOLINT: OpEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, OpStallCycleCount + 0x42108, // NOLINT: OpProgramCounter + 0x42110, // NOLINT: wideToNarrowRunControl + 0x42118, // NOLINT: wideToNarrowRunStatus + 0x42120, // NOLINT: wideToNarrowBreakPoint + 0x42128, // NOLINT: dmaWideToNarrowOverwriteMode + 0x42130, // NOLINT: dmaWideToNarrowEnableTracing + 0x42138, // NOLINT: dmaWideToNarrowStartCycle + 0x42140, // NOLINT: dmaWideToNarrowEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaWideToNarrowStallCycleCount + 0x42148, // NOLINT: dmaWideToNarrowProgramCounter + 0x42150, // NOLINT: narrowToWideRunControl + 0x42158, // NOLINT: narrowToWideRunStatus + 0x42160, // NOLINT: narrowToWideBreakPoint + 0x42168, // NOLINT: dmaNarrowToWideOverwriteMode + 0x42170, // NOLINT: dmaNarrowToWideEnableTracing + 0x42178, // NOLINT: dmaNarrowToWideStartCycle + 0x42180, // NOLINT: dmaNarrowToWideEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaNarrowToWideStallCycleCount + 0x42188, // NOLINT: dmaNarrowToWideProgramCounter + 0x42190, // NOLINT: ringBusConsumer0RunControl + 0x42198, // NOLINT: ringBusConsumer0RunStatus + 0x421a0, // NOLINT: ringBusConsumer0BreakPoint + 0x421a8, // NOLINT: dmaRingBusConsumer0OverwriteMode + 0x421b0, // NOLINT: dmaRingBusConsumer0EnableTracing + 0x421b8, // NOLINT: dmaRingBusConsumer0StartCycle + 0x421c0, // NOLINT: dmaRingBusConsumer0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusConsumer0StallCycleCount + 0x421c8, // NOLINT: dmaRingBusConsumer0ProgramCounter + 0x421d0, // NOLINT: ringBusConsumer1RunControl + 0x421d8, // NOLINT: ringBusConsumer1RunStatus + 0x421e0, // NOLINT: ringBusConsumer1BreakPoint + 0x421e8, // NOLINT: dmaRingBusConsumer1OverwriteMode + 0x421f0, // NOLINT: dmaRingBusConsumer1EnableTracing + 0x421f8, // NOLINT: dmaRingBusConsumer1StartCycle + 0x42200, // NOLINT: dmaRingBusConsumer1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusConsumer1StallCycleCount + 0x42208, // NOLINT: dmaRingBusConsumer1ProgramCounter + 0x42210, // NOLINT: ringBusProducerRunControl + 0x42218, // NOLINT: ringBusProducerRunStatus + 0x42220, // NOLINT: ringBusProducerBreakPoint + 0x42228, // NOLINT: dmaRingBusProducerOverwriteMode + 0x42230, // NOLINT: dmaRingBusProducerEnableTracing + 0x42238, // NOLINT: dmaRingBusProducerStartCycle + 0x42240, // NOLINT: dmaRingBusProducerEndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, + // dmaRingBusProducerStallCycleCount + 0x42248, // NOLINT: dmaRingBusProducerProgramCounter + 0x42250, // NOLINT: meshBus0RunControl + 0x42258, // NOLINT: meshBus0RunStatus + 0x42260, // NOLINT: meshBus0BreakPoint + 0x42270, // NOLINT: dmaMeshBus0OverwriteMode + 0x42278, // NOLINT: 
dmaMeshBus0EnableTracing + 0x42280, // NOLINT: dmaMeshBus0StartCycle + 0x42288, // NOLINT: dmaMeshBus0EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus0StallCycleCount + 0x42290, // NOLINT: dmaMeshBus0ProgramCounter + 0x42298, // NOLINT: meshBus1RunControl + 0x422a0, // NOLINT: meshBus1RunStatus + 0x422a8, // NOLINT: meshBus1BreakPoint + 0x422b8, // NOLINT: dmaMeshBus1OverwriteMode + 0x422c0, // NOLINT: dmaMeshBus1EnableTracing + 0x422c8, // NOLINT: dmaMeshBus1StartCycle + 0x422d0, // NOLINT: dmaMeshBus1EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus1StallCycleCount + 0x422d8, // NOLINT: dmaMeshBus1ProgramCounter + 0x422e0, // NOLINT: meshBus2RunControl + 0x422e8, // NOLINT: meshBus2RunStatus + 0x422f0, // NOLINT: meshBus2BreakPoint + 0x42300, // NOLINT: dmaMeshBus2OverwriteMode + 0x42308, // NOLINT: dmaMeshBus2EnableTracing + 0x42310, // NOLINT: dmaMeshBus2StartCycle + 0x42318, // NOLINT: dmaMeshBus2EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus2StallCycleCount + 0x42320, // NOLINT: dmaMeshBus2ProgramCounter + 0x42328, // NOLINT: meshBus3RunControl + 0x42330, // NOLINT: meshBus3RunStatus + 0x42338, // NOLINT: meshBus3BreakPoint + 0x42348, // NOLINT: dmaMeshBus3OverwriteMode + 0x42350, // NOLINT: dmaMeshBus3EnableTracing + 0x42358, // NOLINT: dmaMeshBus3StartCycle + 0x42360, // NOLINT: dmaMeshBus3EndCycle + kCsrRegisterSpaceInvalidOffset, // UNUSED, dmaMeshBus3StallCycleCount + 0x42368, // NOLINT: dmaMeshBus3ProgramCounter + 0x42370, // NOLINT: Error_Tile + 0x42378, // NOLINT: Error_Mask_Tile + 0x42380, // NOLINT: Error_Force_Tile + 0x42388, // NOLINT: Error_Timestamp_Tile + 0x42390, // NOLINT: Error_Info_Tile + 0x42398, // NOLINT: Timeout + 0x423c0, // NOLINT: opTtuStateRegFile + 0x42400, // NOLINT: OpTrace + 0x42480, // NOLINT: wideToNarrowTtuStateRegFile + 0x42500, // NOLINT: dmaWideToNarrowTrace + 0x42580, // NOLINT: narrowToWideTtuStateRegFile + 0x42600, // NOLINT: dmaNarrowToWideTrace + 0x42620, // NOLINT: ringBusConsumer0TtuStateRegFile + 0x42640, // NOLINT: dmaRingBusConsumer0Trace + 0x42660, // NOLINT: ringBusConsumer1TtuStateRegFile + 0x42680, // NOLINT: dmaRingBusConsumer1Trace + 0x426a0, // NOLINT: ringBusProducerTtuStateRegFile + 0x426c0, // NOLINT: dmaRingBusProducerTrace + 0x42700, // NOLINT: meshBus0TtuStateRegFile + 0x42740, // NOLINT: dmaMeshBus0Trace + 0x42780, // NOLINT: meshBus1TtuStateRegFile + 0x427c0, // NOLINT: dmaMeshBus1Trace + 0x42800, // NOLINT: meshBus2TtuStateRegFile + 0x42840, // NOLINT: dmaMeshBus2Trace + 0x42880, // NOLINT: meshBus3TtuStateRegFile + 0x428c0, // NOLINT: dmaMeshBus3Trace + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryIsolation + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryRetention +}; + +const HibKernelCsrOffsets kBeagleHibKernelCsrOffsets = { + 0x46000, // NOLINT: page_table_size + 0x46008, // NOLINT: extended_table + 0x46050, // NOLINT: dma_pause + 0x46078, // NOLINT: page_table_init + 0x46080, // NOLINT: msix_table_init + 0x50000, // NOLINT: page_table + kCsrRegisterSpaceInvalidOffset, // UNUSED, dma_burst_limiter +}; + +const HibUserCsrOffsets kBeagleHibUserCsrOffsets = { + 0x486b0, // NOLINT: top_level_int_control + 0x486b8, // NOLINT: top_level_int_status + 0x486d0, // NOLINT: sc_host_int_count + 0x486d8, // NOLINT: dma_pause + 0x486e0, // NOLINT: dma_paused + 0x486e8, // NOLINT: status_block_update + 0x486f0, // NOLINT: hib_error_status + 0x486f8, // NOLINT: hib_error_mask + 0x48700, // NOLINT: hib_first_error_status + 0x48708, // NOLINT: hib_first_error_timestamp 
+ 0x48710, // NOLINT: hib_inject_error + 0x487a8, // NOLINT: dma_burst_limiter +}; + +const QueueCsrOffsets kBeagleInstructionQueueCsrOffsets = { + 0x48568, // NOLINT: instruction_queue_control + 0x48570, // NOLINT: instruction_queue_status + 0x48578, // NOLINT: instruction_queue_descriptor_size + 0x48590, // NOLINT: instruction_queue_base + 0x48598, // NOLINT: instruction_queue_status_block_base + 0x485a0, // NOLINT: instruction_queue_size + 0x485a8, // NOLINT: instruction_queue_tail + 0x485b0, // NOLINT: instruction_queue_fetched_head + 0x485b8, // NOLINT: instruction_queue_completed_head + 0x485c0, // NOLINT: instruction_queue_int_control + 0x485c8, // NOLINT: instruction_queue_int_status + 0x48580, // NOLINT: instruction_queue_minimum_size + 0x48588, // NOLINT: instruction_queue_maximum_size + 0x46018, // NOLINT: instruction_queue_int_vector +}; + +const MemoryCsrOffsets kBeagleMemoryMemoryCsrOffsets = { + 0x42010, // NOLINT: memoryAccess + 0x42018, // NOLINT: memoryData +}; + +const ScalarCoreCsrOffsets kBeagleScalarCoreCsrOffsets = { + 0x44018, // NOLINT: scalarCoreRunControl + 0x44038, // NOLINT: executeControl + 0x44158, // NOLINT: avDataPopRunControl + 0x44198, // NOLINT: parameterPopRunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_2RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, scalarDatapath_3RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, executeControl_3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, avDataPop_3RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, parameterPop_3RunControl + 0x441d8, // NOLINT: infeedRunControl + 0x44218, // NOLINT: outfeedRunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, contextControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, contextStatus + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_0_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_0_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_1_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_1_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_2_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_2_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, 
outfeed_3_0RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, infeed_3_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, outfeed_3_1RunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, TilePowerInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakPowerSampleInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpPowerSampleInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtPowerSampleInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakSampleAccumulator + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpSampleAccumulator + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtSampleAccumulator + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakThreshold0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakThreshold1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakThreshold2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakThreshold3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpThreshold0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpThreshold1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpThreshold2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpThreshold3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtThreshold0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakActionTable + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpActionTable + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtActionTable + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakRunningSum + kCsrRegisterSpaceInvalidOffset, // UNUSED, peakRunningSumInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpRunningSum + kCsrRegisterSpaceInvalidOffset, // UNUSED, tdpRunningSumInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtRunningSum + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtRunningSumInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtDifference + kCsrRegisterSpaceInvalidOffset, // UNUSED, packageTdpAction + kCsrRegisterSpaceInvalidOffset, // UNUSED, ThrottleStallCounter + kCsrRegisterSpaceInvalidOffset, // UNUSED, cycleCount +}; + +const MemoryCsrOffsets kBeagleScmemoryMemoryCsrOffsets = { + 0x44040, // NOLINT: scMemoryAccess + 0x44048, // NOLINT: scMemoryData +}; + +const TileConfigCsrOffsets kBeagleTileConfigCsrOffsets = { + 0x48788, // NOLINT: tileconfig0 + 0x48790, // NOLINT: tileconfig1 +}; + +const TileCsrOffsets kBeagleTileCsrOffsets = { + 0x400c0, // NOLINT: opRunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToNarrowRunControl + 0x40150, // NOLINT: narrowToWideRunControl + 0x40110, // NOLINT: wideToNarrowRunControl + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_3 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_4 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_4 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_4 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_5 + 
kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_5 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_5 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_6 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_6 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_6 + kCsrRegisterSpaceInvalidOffset, // UNUSED, opRunControl_7 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowToWideRunControl_7 + kCsrRegisterSpaceInvalidOffset, // UNUSED, wideToNarrowRunControl_7 + 0x40190, // NOLINT: ringBusConsumer0RunControl + 0x401d0, // NOLINT: ringBusConsumer1RunControl + 0x40210, // NOLINT: ringBusProducerRunControl + 0x40250, // NOLINT: meshBus0RunControl + 0x40298, // NOLINT: meshBus1RunControl + 0x402e0, // NOLINT: meshBus2RunControl + 0x40328, // NOLINT: meshBus3RunControl + 0x40020, // NOLINT: deepSleep + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryIsolation + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryRetention + kCsrRegisterSpaceInvalidOffset, // UNUSED, EnergyTable + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtSampleInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtRunningSumInterval + kCsrRegisterSpaceInvalidOffset, // UNUSED, opAccumulateRegister + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtRunningSumRegister + kCsrRegisterSpaceInvalidOffset, // UNUSED, didtThreshold0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_0 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_1 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_2 + kCsrRegisterSpaceInvalidOffset, // UNUSED, narrowMemoryContext_3 +}; + +const WireCsrOffsets kBeagleWireCsrOffsets = { + 0x48778, // NOLINT: wire_int_pending_bit_array + 0x48780, // NOLINT: wire_int_mask_array +}; + +const InterruptCsrOffsets kBeagleUsbFatalErrIntInterruptCsrOffsets = { + 0x4c060, // NOLINT: fatal_err_int_control + 0x4c068, // NOLINT: fatal_err_int_status +}; + +const InterruptCsrOffsets kBeagleUsbScHostInt0InterruptCsrOffsets = { + 0x4c0b0, // NOLINT: sc_host_int_0_control + 0x4c0b8, // NOLINT: sc_host_int_0_status +}; + +const InterruptCsrOffsets kBeagleUsbScHostInt1InterruptCsrOffsets = { + 0x4c0c8, // NOLINT: sc_host_int_1_control + 0x4c0d0, // NOLINT: sc_host_int_1_status +}; + +const InterruptCsrOffsets kBeagleUsbScHostInt2InterruptCsrOffsets = { + 0x4c0e0, // NOLINT: sc_host_int_2_control + 0x4c0e8, // NOLINT: sc_host_int_2_status +}; + +const InterruptCsrOffsets kBeagleUsbScHostInt3InterruptCsrOffsets = { + 0x4c0f8, // NOLINT: sc_host_int_3_control + 0x4c100, // NOLINT: sc_host_int_3_status +}; + +const InterruptCsrOffsets kBeagleUsbTopLevelInt0InterruptCsrOffsets = { + 0x4c070, // NOLINT: top_level_int_0_control + 0x4c078, // NOLINT: top_level_int_0_status +}; + +const InterruptCsrOffsets kBeagleUsbTopLevelInt1InterruptCsrOffsets = { + 0x4c080, // NOLINT: top_level_int_1_control + 0x4c088, // NOLINT: top_level_int_1_status +}; + +const InterruptCsrOffsets kBeagleUsbTopLevelInt2InterruptCsrOffsets = { + 0x4c090, // NOLINT: top_level_int_2_control + 0x4c098, // NOLINT: top_level_int_2_status +}; + +const InterruptCsrOffsets kBeagleUsbTopLevelInt3InterruptCsrOffsets = { + 0x4c0a0, // NOLINT: top_level_int_3_control + 0x4c0a8, // NOLINT: top_level_int_3_status +}; + +const ApexCsrOffsets kBeagleApexCsrOffsets = { + 0x1a000, // NOLINT: omc0_00 + 0x1a0d4, // NOLINT: omc0_d4 + 0x1a0d8, // NOLINT: omc0_d8 + 0x1a0dc, // NOLINT: omc0_dc + 0x1a600, // NOLINT: mst_abm_en + 0x1a500, // NOLINT: slv_abm_en + 
0x1a558, // NOLINT: slv_err_resp_isr_mask + 0x1a658, // NOLINT: mst_err_resp_isr_mask + 0x1a640, // NOLINT: mst_wr_err_resp + 0x1a644, // NOLINT: mst_rd_err_resp + 0x1a540, // NOLINT: slv_wr_err_resp + 0x1a544, // NOLINT: slv_rd_err_resp + 0x1a704, // NOLINT: rambist_ctrl_1 + 0x1a200, // NOLINT: efuse_00 +}; + +const CbBridgeCsrOffsets kBeagleCbBridgeCsrOffsets = { + 0x19018, // NOLINT: bo0_fifo_status + 0x1901c, // NOLINT: bo1_fifo_status + 0x19020, // NOLINT: bo2_fifo_status + 0x19024, // NOLINT: bo3_fifo_status + 0x1907c, // NOLINT: gcbb_credit0 +}; + +const MiscCsrOffsets kBeagleMiscCsrOffsets = { + 0x4a000, // NOLINT: idleRegister +}; + +const MsixCsrOffsets kBeagleMsixCsrOffsets = { + 0x46018, // NOLINT: instruction_queue_int_vector + 0x46020, // NOLINT: input_actv_queue_int_vector + 0x46028, // NOLINT: param_queue_int_vector + 0x46030, // NOLINT: output_actv_queue_int_vector + 0x46040, // NOLINT: top_level_int_vector + 0x46038, // NOLINT: sc_host_int_vector + 0x46048, // NOLINT: fatal_err_int_vector + 0x46068, // NOLINT: msix_pending_bit_array0 + 0x46800, // NOLINT: msix_table +}; + +const ScuCsrOffsets kBeagleScuCsrOffsets = { + 0x1a30c, // NOLINT: scu_ctrl_0 + 0x1a310, // NOLINT: scu_ctrl_1 + 0x1a314, // NOLINT: scu_ctrl_2 + 0x1a318, // NOLINT: scu_ctrl_3 + 0x1a31c, // NOLINT: scu_ctrl_4 + 0x1a320, // NOLINT: scu_ctrl_5 + 0x1a32c, // NOLINT: scu_ctr_6 + 0x1a33c, // NOLINT: scu_ctr_7 +}; + +const UsbCsrOffsets kBeagleUsbCsrOffsets = { + 0x4c058, // NOLINT: outfeed_chunk_length + 0x4c148, // NOLINT: descr_ep + 0x4c150, // NOLINT: ep_status_credit + 0x4c160, // NOLINT: multi_bo_ep +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_BEAGLE_BEAGLE_CSR_OFFSETS_H_ diff --git a/driver/config/beagle_csr_helper.h b/driver/config/beagle_csr_helper.h new file mode 100644 index 0000000..474795a --- /dev/null +++ b/driver/config/beagle_csr_helper.h @@ -0,0 +1,544 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_BEAGLE_CSR_HELPER_H_ +#define DARWINN_DRIVER_CONFIG_BEAGLE_CSR_HELPER_H_ + +#include "driver/bitfield.h" +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { +namespace registers { + +// CSR helper to access fields for omc0_d4 CSR. +class Omc0D4 { + public: + // Defaults to reset value. 
+ Omc0D4() : Omc0D4(0x1ULL) {} + explicit Omc0D4(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_method_sel(uint64 value) { reg_.method_sel_ = value; } + uint64 method_sel() const { return reg_.method_sel_(); } + + void set_thm_warn1(uint64 value) { reg_.thm_warn1_ = value; } + uint64 thm_warn1() const { return reg_.thm_warn1_(); } + + void set_thm_warn_en(uint64 value) { reg_.thm_warn_en_ = value; } + uint64 thm_warn_en() const { return reg_.thm_warn_en_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 1> method_sel_; + platforms::darwinn::driver::Bitfield<1, 15> field_01_; + platforms::darwinn::driver::Bitfield<16, 10> thm_warn1_; + platforms::darwinn::driver::Bitfield<26, 5> field_26_; + platforms::darwinn::driver::Bitfield<31, 1> thm_warn_en_; + } reg_; +}; + +// CSR helper to access fields for omc0_d8 CSR. +class Omc0D8 { + public: + // Defaults to reset value. + Omc0D8() : Omc0D8(0ULL) {} + explicit Omc0D8(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_enbg(uint64 value) { reg_.enbg_ = value; } + uint64 enbg() const { return reg_.enbg_(); } + + void set_envr(uint64 value) { reg_.envr_ = value; } + uint64 envr() const { return reg_.envr_(); } + + void set_enad(uint64 value) { reg_.enad_ = value; } + uint64 enad() const { return reg_.enad_(); } + + void set_thm_warn2(uint64 value) { reg_.thm_warn2_ = value; } + uint64 thm_warn2() const { return reg_.thm_warn2_(); } + + void set_sd_en(uint64 value) { reg_.sd_en_ = value; } + uint64 sd_en() const { return reg_.sd_en_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 1> enbg_; + platforms::darwinn::driver::Bitfield<1, 1> envr_; + platforms::darwinn::driver::Bitfield<2, 1> enad_; + platforms::darwinn::driver::Bitfield<3, 13> field_03_; + platforms::darwinn::driver::Bitfield<16, 10> thm_warn2_; + platforms::darwinn::driver::Bitfield<26, 5> field_26_; + platforms::darwinn::driver::Bitfield<31, 1> sd_en_; + } reg_; +}; + +// CSR helper to access fields for omc0_dc CSR. +class Omc0DC { + public: + // Defaults to reset value. + Omc0DC() : Omc0DC(0ULL) { set_data(0x3FF); } + explicit Omc0DC(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_data(uint64 value) { reg_.data_ = value; } + uint64 data() const { return reg_.data_(); } + + void set_sd_clear(uint64 value) { reg_.sd_clear_ = value; } + uint64 sd_clear() const { return reg_.sd_clear_(); } + + void set_warn_clear(uint64 value) { reg_.warn_clear_ = value; } + uint64 warn_clear() const { return reg_.warn_clear_(); } + + // Read-Only. + uint64 sd_o() const { return reg_.sd_o_(); } + uint64 warn_o() const { return reg_.warn_o_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. 
+ platforms::darwinn::driver::Bitfield<0, 1> enthmc_; + platforms::darwinn::driver::Bitfield<1, 15> field_01_; + platforms::darwinn::driver::Bitfield<16, 10> data_; + platforms::darwinn::driver::Bitfield<26, 2> field_26_; + platforms::darwinn::driver::Bitfield<28, 1> sd_clear_; + platforms::darwinn::driver::Bitfield<29, 1> warn_clear_; + platforms::darwinn::driver::Bitfield<30, 1> sd_o_; + platforms::darwinn::driver::Bitfield<31, 1> warn_o_; + } reg_; +}; + +// CSR helper to access fields for rambist_ctrl_1 CSR. +class RamBistCtrl1 { + public: + // Defaults to reset value. + RamBistCtrl1() : RamBistCtrl1(0ULL) { + set_rg_rambist_gcbsel(0x1F); + set_rg_rambist_topsel(0x3); + set_rg_mbist_int_mask(0x7); + } + explicit RamBistCtrl1(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_rambist_gcbsel(uint64 value) { reg_.rg_rambist_gcbsel_ = value; } + uint64 rg_rambist_gcbsel() const { return reg_.rg_rambist_gcbsel_(); } + + void set_rg_rambist_topsel(uint64 value) { reg_.rg_rambist_topsel_ = value; } + uint64 rg_rambist_topsel() const { return reg_.rg_rambist_topsel_(); } + + void set_rg_rambist_tckmode(uint64 value) { + reg_.rg_rambist_tckmode_ = value; + } + uint64 rg_rambist_tckmode() const { return reg_.rg_rambist_tckmode_(); } + + void set_rg_rambist_req(uint64 value) { reg_.rg_rambist_req_ = value; } + uint64 rg_rambist_req() const { return reg_.rg_rambist_req_(); } + + void set_rg_tck_invert(uint64 value) { reg_.rg_tck_invert_ = value; } + uint64 rg_tck_invert() const { return reg_.rg_tck_invert_(); } + + void set_mbist_status(uint64 value) { reg_.mbist_status_ = value; } + uint64 mbist_status() const { return reg_.mbist_status_(); } + + void set_rg_mbist_int_status(uint64 value) { + reg_.rg_mbist_int_status_ = value; + } + uint64 rg_mbist_int_status() const { return reg_.rg_mbist_int_status_(); } + + void set_rg_mbist_int_mask(uint64 value) { reg_.rg_mbist_int_mask_ = value; } + uint64 rg_mbist_int_mask() const { return reg_.rg_mbist_int_mask_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 5> rg_rambist_gcbsel_; + platforms::darwinn::driver::Bitfield<5, 2> rg_rambist_topsel_; + platforms::darwinn::driver::Bitfield<7, 1> field_07_; + platforms::darwinn::driver::Bitfield<8, 1> rg_rambist_tckmode_; + platforms::darwinn::driver::Bitfield<9, 1> rg_rambist_req_; + platforms::darwinn::driver::Bitfield<10, 1> rg_tck_invert_; + platforms::darwinn::driver::Bitfield<11, 1> field_11_; + platforms::darwinn::driver::Bitfield<12, 2> mbist_status_; + platforms::darwinn::driver::Bitfield<14, 2> field_14_; + platforms::darwinn::driver::Bitfield<16, 3> rg_mbist_int_status_; + platforms::darwinn::driver::Bitfield<19, 1> field_19_; + platforms::darwinn::driver::Bitfield<20, 3> rg_mbist_int_mask_; + platforms::darwinn::driver::Bitfield<23, 9> field_23_; + } reg_; +}; + +// CSR helper to access fields for efuse_00 CSR. +class Efuse00 { + public: + // Defaults to reset value. + Efuse00() : Efuse00(0ULL) {} + explicit Efuse00(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_ef_int_mask(uint64 value) { reg_.ef_int_mask_ = value; } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. 
+ platforms::darwinn::driver::Bitfield<0, 1> ef_single_step_dis_; + platforms::darwinn::driver::Bitfield<1, 2> ef_prod_sel_; + platforms::darwinn::driver::Bitfield<3, 3> ef_refclk_sel_ovr_; + platforms::darwinn::driver::Bitfield<6, 1> ef_pcie_gen1_link_; + platforms::darwinn::driver::Bitfield<7, 1> ef_usb_ssc_mode_0_; + platforms::darwinn::driver::Bitfield<8, 5> ef_i2caddr_ovr_; + platforms::darwinn::driver::Bitfield<13, 3> ef_psigma_; + platforms::darwinn::driver::Bitfield<16, 1> ef_mbist_dis_; + platforms::darwinn::driver::Bitfield<17, 1> ef_w_dis_; + platforms::darwinn::driver::Bitfield<18, 1> ef_thm_int_mask_; + platforms::darwinn::driver::Bitfield<19, 1> ef_int_mask_; + platforms::darwinn::driver::Bitfield<20, 2> ef_pwr_state_dis_; + platforms::darwinn::driver::Bitfield<22, 1> ef_usb_ssc_mode_1_; + platforms::darwinn::driver::Bitfield<23, 1> ef_8051_rom_500m_; + platforms::darwinn::driver::Bitfield<24, 8> ef_pll_M_; + } reg_; +}; + +// CSR helper to access fields for scu_ctrl_0 CSR. +class ScuCtrl0 { + public: + // Defaults to reset value. + ScuCtrl0() : ScuCtrl0(0ULL) { + set_rg_pllclk_sel(1); + set_rg_usb_slp_phy_mode(1); + set_rg_pcie_inact_phy_mode(1); + set_rg_usb_inact_phy_mode(1); + } + explicit ScuCtrl0(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_pllclk_sel(uint64 value) { reg_.rg_pllclk_sel_ = value; } + uint64 rg_pllclk_sel() const { return reg_.rg_pllclk_sel_(); } + + void set_rg_usb_slp_phy_mode(uint64 value) { + reg_.rg_usb_slp_phy_mode_ = value; + } + uint64 rg_usb_slp_phy_mode() const { return reg_.rg_usb_slp_phy_mode_(); } + + void set_rg_pcie_inact_phy_mode(uint64 value) { + reg_.rg_pcie_inact_phy_mode_ = value; + } + uint64 rg_pcie_inact_phy_mode() const { + return reg_.rg_pcie_inact_phy_mode_(); + } + + void set_rg_usb_inact_phy_mode(uint64 value) { + reg_.rg_usb_inact_phy_mode_ = value; + } + uint64 rg_usb_inact_phy_mode() const { return reg_.rg_usb_inact_phy_mode_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 1> rg_pllclk_sel_; + platforms::darwinn::driver::Bitfield<1, 1> rg_single_exit_; + platforms::darwinn::driver::Bitfield<2, 1> rg_single_link_rstn_; + platforms::darwinn::driver::Bitfield<3, 1> rg_sleep_chk_idle_; + platforms::darwinn::driver::Bitfield<4, 2> rg_pcie_slp_phy_mode_; + platforms::darwinn::driver::Bitfield<6, 2> rg_usb_slp_phy_mode_; + platforms::darwinn::driver::Bitfield<8, 3> rg_pcie_inact_phy_mode_; + platforms::darwinn::driver::Bitfield<11, 3> rg_usb_inact_phy_mode_; + platforms::darwinn::driver::Bitfield<14, 2> rg_mem_mode_dis_; + platforms::darwinn::driver::Bitfield<16, 1> rg_phy_prg_; + platforms::darwinn::driver::Bitfield<17, 1> bt_phy_prg_; + platforms::darwinn::driver::Bitfield<18, 1> bt_vbus_sel_; + platforms::darwinn::driver::Bitfield<19, 1> bt_bus_pwr_; + } reg_; +}; + +// CSR helper to access fields for scu_ctrl_2 CSR. +class ScuCtrl2 { + public: + // Defaults to reset value. + ScuCtrl2() : ScuCtrl2(0ULL) {} + explicit ScuCtrl2(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_gated_gcb(uint64 value) { reg_.rg_gated_gcb_ = value; } + uint64 rg_gated_gcb() const { return reg_.rg_gated_gcb_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. 
+ platforms::darwinn::driver::Bitfield<0, 1> rg_rst_pcie_; + platforms::darwinn::driver::Bitfield<1, 1> rg_rst_pcie_axi_; + platforms::darwinn::driver::Bitfield<2, 2> rg_rst_gcb_; + platforms::darwinn::driver::Bitfield<4, 1> rg_rst_pcieslv_abm_; + platforms::darwinn::driver::Bitfield<5, 1> rg_rst_pciemst_abm_; + platforms::darwinn::driver::Bitfield<6, 1> rg_rst_omc_; + platforms::darwinn::driver::Bitfield<7, 1> rg_rst_mbist_; + platforms::darwinn::driver::Bitfield<8, 1> rg_rst_usb_; + platforms::darwinn::driver::Bitfield<9, 1> rg_rst_usb_subsys_; + platforms::darwinn::driver::Bitfield<10, 1> rg_rst_full_; + platforms::darwinn::driver::Bitfield<11, 1> rg_rst_link_; + platforms::darwinn::driver::Bitfield<12, 1> rg_rst_i2c_; + platforms::darwinn::driver::Bitfield<13, 1> rg_rst_scu_; + platforms::darwinn::driver::Bitfield<14, 1> rg_self_rst_subsys_; + platforms::darwinn::driver::Bitfield<15, 1> rg_rst_brg_; + platforms::darwinn::driver::Bitfield<16, 1> rg_gated_pcie_; + platforms::darwinn::driver::Bitfield<17, 1> rg_gated_phy_cfg_; + platforms::darwinn::driver::Bitfield<18, 2> rg_gated_gcb_; + platforms::darwinn::driver::Bitfield<20, 1> rg_gated_pcieslv_abm_; + platforms::darwinn::driver::Bitfield<21, 1> rg_gated_pciemst_abm_; + platforms::darwinn::driver::Bitfield<22, 1> rg_gated_omc_; + platforms::darwinn::driver::Bitfield<23, 1> rg_gated_mbist_; + platforms::darwinn::driver::Bitfield<24, 1> rg_gated_usb_; + platforms::darwinn::driver::Bitfield<25, 1> rg_gated_usb_subsys_; + platforms::darwinn::driver::Bitfield<26, 1> rg_gated_8051_; + } reg_; +}; + +// CSR helper to access fields for scu_ctrl_3 CSR. +class ScuCtrl3 { + public: + enum class GcbClock { + k63MHZ, + k125MHZ, + k250MHZ, + k500MHZ, + }; + enum class AxiClock { + k125MHZ, + k250MHZ, + }; + enum class Usb8051Clock { + k250MHZ, + k500MHZ, + }; + + // Defaults to reset value. 
+ ScuCtrl3() : ScuCtrl3(0x80050410ULL) {} + explicit ScuCtrl3(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_force_sleep(uint64 value) { reg_.rg_force_sleep_ = value; } + uint64 rg_force_sleep() const { return reg_.rg_force_sleep_(); } + + void set_cur_pwr_state(uint64 value) { reg_.cur_pwr_state_ = value; } + uint64 cur_pwr_state() const { return reg_.cur_pwr_state_(); } + + void set_gcb_clock_rate(GcbClock rate) { + switch (rate) { + case GcbClock::k63MHZ: + reg_.rg_gcb_clkdiv_ = 3; + break; + case GcbClock::k125MHZ: + reg_.rg_gcb_clkdiv_ = 2; + break; + case GcbClock::k250MHZ: + reg_.rg_gcb_clkdiv_ = 1; + break; + case GcbClock::k500MHZ: + reg_.rg_gcb_clkdiv_ = 0; + break; + } + } + + GcbClock gcb_clock_rate() const { + switch (reg_.rg_gcb_clkdiv_()) { + case 3: + return GcbClock::k63MHZ; + case 2: + return GcbClock::k125MHZ; + case 1: + return GcbClock::k250MHZ; + default: + return GcbClock::k500MHZ; + } + } + + void set_axi_clock_rate(AxiClock rate) { + switch (rate) { + case AxiClock::k125MHZ: + reg_.rg_axi_clk_125m_ = 1; + break; + case AxiClock::k250MHZ: + reg_.rg_axi_clk_125m_ = 0; + break; + } + } + + AxiClock axi_clock_rate() const { + if (reg_.rg_axi_clk_125m_()) { + return AxiClock::k125MHZ; + } else { + return AxiClock::k250MHZ; + } + } + + void set_usb_8051_clock_rate(Usb8051Clock rate) { + switch (rate) { + case Usb8051Clock::k250MHZ: + reg_.rg_8051_clk_250m_ = 1; + break; + case Usb8051Clock::k500MHZ: + reg_.rg_8051_clk_250m_ = 0; + break; + } + } + + Usb8051Clock usb_8051_clock_rate() const { + if (reg_.rg_8051_clk_250m_()) { + return Usb8051Clock::k250MHZ; + } else { + return Usb8051Clock::k500MHZ; + } + } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 1> pcie_state_l1p2_; + platforms::darwinn::driver::Bitfield<1, 1> pcie_state_l0s_; + platforms::darwinn::driver::Bitfield<2, 1> pcie_state_l0_; + platforms::darwinn::driver::Bitfield<3, 1> cur_gated_gcb_; + platforms::darwinn::driver::Bitfield<4, 1> cur_rst_gcb_; + platforms::darwinn::driver::Bitfield<5, 1> field_05_; + platforms::darwinn::driver::Bitfield<6, 1> cur_gcb_sram_sd_; + platforms::darwinn::driver::Bitfield<7, 1> cur_gcb_sram_dslp_; + platforms::darwinn::driver::Bitfield<8, 2> cur_pwr_state_; + platforms::darwinn::driver::Bitfield<10, 2> pcie_gen_info_; + platforms::darwinn::driver::Bitfield<12, 2> rg_force_ram_dslp_; + platforms::darwinn::driver::Bitfield<14, 2> rg_force_ram_sd_; + platforms::darwinn::driver::Bitfield<16, 3> rg_sd2wk_dly_; + platforms::darwinn::driver::Bitfield<19, 1> rg_slp_mode_req_; + platforms::darwinn::driver::Bitfield<20, 2> rg_force_inact_; + platforms::darwinn::driver::Bitfield<22, 2> rg_force_sleep_; + platforms::darwinn::driver::Bitfield<24, 1> field_24_; + platforms::darwinn::driver::Bitfield<25, 1> rg_link_rdy_ovr_; + platforms::darwinn::driver::Bitfield<26, 2> rg_pwr_state_ovr_; + platforms::darwinn::driver::Bitfield<28, 2> rg_gcb_clkdiv_; + platforms::darwinn::driver::Bitfield<30, 1> rg_axi_clk_125m_; + platforms::darwinn::driver::Bitfield<31, 1> rg_8051_clk_250m_; + } reg_; +}; + +// CSR helper to access fields for scu_ctrl_6 CSR. +class ScuCtrl6 { + public: + // Defaults to reset value. 
+ ScuCtrl6() : ScuCtrl6(0ULL) {} + explicit ScuCtrl6(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_gcb_spare_in(uint64 value) { reg_.rg_gcb_spare_in_ = value; } + uint64 rg_gcb_spare_in() const { return reg_.rg_gcb_spare_in_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 2> rg_pad_ds_; + platforms::darwinn::driver::Bitfield<2, 2> rg_pad_ds_i2c_; + platforms::darwinn::driver::Bitfield<4, 2> rg_pad_ds_gpio_; + platforms::darwinn::driver::Bitfield<6, 2> rg_pad_ds_xin_; + platforms::darwinn::driver::Bitfield<8, 8> rg_pinmux_sel_; + platforms::darwinn::driver::Bitfield<16, 4> rg_gcb_spare_in_; + platforms::darwinn::driver::Bitfield<20, 4> gcb_spare_out_; + platforms::darwinn::driver::Bitfield<24, 1> warning_o_; + platforms::darwinn::driver::Bitfield<25, 1> int_mbist_; + platforms::darwinn::driver::Bitfield<26, 1> err_resp_isr_0_; + platforms::darwinn::driver::Bitfield<27, 1> err_resp_isr_1_; + platforms::darwinn::driver::Bitfield<28, 2> rg_jtag_sel_; + platforms::darwinn::driver::Bitfield<30, 1> rg_jtag_io_sel_; + platforms::darwinn::driver::Bitfield<31, 1> field_31_; + } reg_; +}; + +// CSR helper to access fields for scu_ctr_7 CSR. +class ScuCtrl7 { + public: + // Defaults to reset value. + ScuCtrl7() : ScuCtrl7(0ULL) { + set_rg_inact_thd(0x3F); + set_rg_boot_failure_mask(0x3); + } + explicit ScuCtrl7(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_rg_boot_failure_mask(uint64 value) { + reg_.rg_boot_failure_mask_ = value; + } + uint64 rg_boot_failure_mask() const { return reg_.rg_boot_failure_mask_(); } + + void set_rg_inact_thd(uint64 value) { reg_.rg_inact_thd_ = value; } + uint64 rg_inact_thd() const { return reg_.rg_inact_thd_(); } + + void set_rg_boot_failure_raw(uint64 value) { + reg_.rg_boot_failure_raw_ = value; + } + uint64 rg_boot_failure_raw() const { return reg_.rg_boot_failure_raw_(); } + + void set_pll_lock_failure(uint64 value) { reg_.pll_lock_failure_ = value; } + uint64 pll_lock_failure() const { return reg_.pll_lock_failure_(); } + + void set_usb_sel_failure(uint64 value) { reg_.usb_sel_failure_ = value; } + uint64 usb_sel_failure() const { return reg_.usb_sel_failure_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 16> rg_inact_thd_; + platforms::darwinn::driver::Bitfield<16, 1> pll_lock_failure_; + platforms::darwinn::driver::Bitfield<17, 1> usb_sel_failure_; + platforms::darwinn::driver::Bitfield<18, 2> rg_boot_failure_mask_; + platforms::darwinn::driver::Bitfield<20, 2> rg_boot_failure_raw_; + platforms::darwinn::driver::Bitfield<22, 10> field_22_; + } reg_; +}; + +} // namespace registers +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_BEAGLE_CSR_HELPER_H_ diff --git a/driver/config/breakpoint_csr_offsets.h b/driver/config/breakpoint_csr_offsets.h new file mode 100644 index 0000000..2cb9cb4 --- /dev/null +++ b/driver/config/breakpoint_csr_offsets.h @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_CONFIG_BREAKPOINT_CSR_OFFSETS_H_
+#define DARWINN_DRIVER_CONFIG_BREAKPOINT_CSR_OFFSETS_H_
+
+#include "port/integral_types.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+namespace config {
+
+// This struct holds various CSR offsets for breakpointing a hardware
+// component, e.g. scalar core, TTU, etc.
+// Members are intentionally named to match the GCSR register names.
+struct BreakpointCsrOffsets {
+  uint64 RunControl;
+  uint64 RunStatus;
+  uint64 BreakPoint;
+};
+
+}  // namespace config
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_CONFIG_BREAKPOINT_CSR_OFFSETS_H_
diff --git a/driver/config/cb_bridge_csr_offsets.h b/driver/config/cb_bridge_csr_offsets.h
new file mode 100644
index 0000000..2f9ae0d
--- /dev/null
+++ b/driver/config/cb_bridge_csr_offsets.h
@@ -0,0 +1,40 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_CONFIG_CB_BRIDGE_CSR_OFFSETS_H_
+#define DARWINN_DRIVER_CONFIG_CB_BRIDGE_CSR_OFFSETS_H_
+
+#include "port/integral_types.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+namespace config {
+
+// This struct holds various CSR offsets for cb_bridge in Beagle.
+// Members are intentionally named to match the GCSR register names.
+struct CbBridgeCsrOffsets {
+  uint64 bo0_fifo_status;
+  uint64 bo1_fifo_status;
+  uint64 bo2_fifo_status;
+  uint64 bo3_fifo_status;
+  uint64 gcbb_credit0;
+};
+
+}  // namespace config
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_CONFIG_CB_BRIDGE_CSR_OFFSETS_H_
diff --git a/driver/config/chip_config.h b/driver/config/chip_config.h
new file mode 100644
index 0000000..76f6a49
--- /dev/null
+++ b/driver/config/chip_config.h
@@ -0,0 +1,399 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
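+
+// Usage sketch: ChipConfig (declared below) is the chip-independent lookup
+// interface for the CSR offset structs in this directory. A chip-specific
+// implementation is expected to return the matching constants, roughly like
+// the following (the subclass name here is illustrative only):
+//
+//   class BeagleChipConfig : public ChipConfig {
+//     const HibKernelCsrOffsets& GetHibKernelCsrOffsets() const override {
+//       // kBeagleHibKernelCsrOffsets comes from beagle_csr_offsets.h above.
+//       return kBeagleHibKernelCsrOffsets;
+//     }
+//     // ...one override per accessor...
+//   };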
+ +#ifndef DARWINN_DRIVER_CONFIG_CHIP_CONFIG_H_ +#define DARWINN_DRIVER_CONFIG_CHIP_CONFIG_H_ + +#include "api/chip.h" +#include "driver/config/apex_csr_offsets.h" +#include "driver/config/breakpoint_csr_offsets.h" +#include "driver/config/cb_bridge_csr_offsets.h" +#include "driver/config/chip_structures.h" +#include "driver/config/debug_hib_user_csr_offsets.h" +#include "driver/config/debug_scalar_core_csr_offsets.h" +#include "driver/config/debug_tile_csr_offsets.h" +#include "driver/config/hib_kernel_csr_offsets.h" +#include "driver/config/hib_user_csr_offsets.h" +#include "driver/config/interrupt_csr_offsets.h" +#include "driver/config/memory_csr_offsets.h" +#include "driver/config/misc_csr_offsets.h" +#include "driver/config/msix_csr_offsets.h" +#include "driver/config/queue_csr_offsets.h" +#include "driver/config/register_file_csr_offsets.h" +#include "driver/config/scalar_core_csr_offsets.h" +#include "driver/config/scu_csr_offsets.h" +#include "driver/config/sync_flag_csr_offsets.h" +#include "driver/config/tile_config_csr_offsets.h" +#include "driver/config/tile_csr_offsets.h" +#include "driver/config/tile_thread_csr_offsets.h" +#include "driver/config/trace_csr_offsets.h" +#include "driver/config/usb_csr_offsets.h" +#include "driver/config/wire_csr_offsets.h" +#include "port/logging.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// Project-independent interface for CSR offsets and system constants. +class ChipConfig { + public: + virtual ~ChipConfig() = default; + + virtual api::Chip GetChip() const = 0; + + // Extracts CSR offsets for various modules in DarwiNN. + virtual const HibKernelCsrOffsets& GetHibKernelCsrOffsets() const = 0; + virtual const HibUserCsrOffsets& GetHibUserCsrOffsets() const = 0; + virtual const QueueCsrOffsets& GetInstructionQueueCsrOffsets() const = 0; + virtual const HibUserCsrOffsets& GetContextSpecificHibUserCsrOffsets( + int context_id) const = 0; + virtual const QueueCsrOffsets& GetContextSpecificInstructionQueueCsrOffsets( + int context_id) const = 0; + virtual const ScalarCoreCsrOffsets& GetScalarCoreCsrOffsets() const = 0; + virtual const TileConfigCsrOffsets& GetTileConfigCsrOffsets() const = 0; + virtual const TileCsrOffsets& GetTileCsrOffsets() const = 0; + virtual bool HasThreadCsrOffsets() const { return false; } + virtual const TileThreadCsrOffsets& GetTileThread0CsrOffsets() const { + LOG(FATAL) << "Tile thread 0 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread1CsrOffsets() const { + LOG(FATAL) << "Tile thread 1 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread2CsrOffsets() const { + LOG(FATAL) << "Tile thread 2 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread3CsrOffsets() const { + LOG(FATAL) << "Tile thread 3 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread4CsrOffsets() const { + LOG(FATAL) << "Tile thread 4 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread5CsrOffsets() const { + LOG(FATAL) << "Tile thread 5 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread6CsrOffsets() const { + LOG(FATAL) << "Tile thread 6 not supported."; + unreachable(); + } + virtual const TileThreadCsrOffsets& GetTileThread7CsrOffsets() const { + LOG(FATAL) << "Tile thread 7 not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& 
GetScalarCoreInterruptCsrOffsets() + const = 0; + virtual const InterruptCsrOffsets& GetTopLevelInterruptCsrOffsets() const = 0; + virtual const InterruptCsrOffsets& GetFatalErrorInterruptCsrOffsets() + const = 0; + virtual const InterruptCsrOffsets& + GetContextSpecificScalarCoreInterruptCsrOffsets(int context_id) const = 0; + virtual const InterruptCsrOffsets& + GetContextSpecificTopLevelInterruptCsrOffsets(int context_id) const = 0; + virtual const InterruptCsrOffsets& + GetContextSpecificFatalErrorInterruptCsrOffsets(int context_id) const = 0; + // Extracts CSR offsets that supports specific functionality in DarwiNN. + virtual const MsixCsrOffsets& GetMsixCsrOffsets() const { + LOG(FATAL) << "MSIX interrupt not supported."; + unreachable(); + } + virtual const WireCsrOffsets& GetWireCsrOffsets() const = 0; + virtual const WireCsrOffsets& GetContextSpecificWireCsrOffsets( + int context_id) const = 0; + virtual const MiscCsrOffsets& GetMiscCsrOffsets() const { + LOG(FATAL) << "Misc not supported."; + unreachable(); + } + + // Extracts chip-specific constants in DarwiNN. + virtual const ChipStructures& GetChipStructures() const = 0; + + // Extracts CSR offsets used by scalar core debugger in DarwiNN. + virtual const BreakpointCsrOffsets& GetScalarCoreBreakpointCsrOffsets() + const = 0; + virtual const BreakpointCsrOffsets& + GetScalarCoreActivationTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetScalarCoreInfeedTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetScalarCoreOutfeedTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetScalarCoreParameterTtuBreakpointCsrOffsets() const = 0; + + virtual const RegisterFileCsrOffsets& GetScalarRegisterFileCsrOffsets() + const = 0; + virtual const RegisterFileCsrOffsets& GetPredicateRegisterFileCsrOffsets() + const = 0; + + virtual const MemoryCsrOffsets& GetScalarCoreMemoryCsrOffsets() const = 0; + + // Extracts CSR offsets used by tile debugger in DarwiNN. + virtual const BreakpointCsrOffsets& GetTileOpTtuBreakpointCsrOffsets() + const = 0; + virtual const BreakpointCsrOffsets& + GetTileWideToNarrowTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetTileNarrowToWideTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetTileRingBusConsumer0TtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetTileRingBusConsumer1TtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& + GetTileRingBusProducerTtuBreakpointCsrOffsets() const = 0; + virtual const BreakpointCsrOffsets& GetTileMeshBus0TtuBreakpointCsrOffsets() + const = 0; + virtual const BreakpointCsrOffsets& GetTileMeshBus1TtuBreakpointCsrOffsets() + const = 0; + virtual const BreakpointCsrOffsets& GetTileMeshBus2TtuBreakpointCsrOffsets() + const = 0; + virtual const BreakpointCsrOffsets& GetTileMeshBus3TtuBreakpointCsrOffsets() + const = 0; + + virtual const MemoryCsrOffsets& GetTileMemoryCsrOffsets() const = 0; + + // Extracts CSR offsets used by scalar core performance tracing. + virtual const TraceCsrOffsets& GetScalarCoreActivationTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetScalarCoreInfeedTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetScalarCoreOutfeedTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetScalarCoreParameterTtuTraceCsrOffsets() + const = 0; + + // Extracts CSR offsets used by tile performance tracing. 
+ virtual const TraceCsrOffsets& GetTileOpTtuTraceCsrOffsets() const = 0; + virtual const TraceCsrOffsets& GetTileWideToNarrowTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetTileNarrowToWideTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetTileRingBusConsumer0TtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetTileRingBusConsumer1TtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetTileRingBusProducerTtuTraceCsrOffsets() + const = 0; + virtual const TraceCsrOffsets& GetTileMeshBus0TtuTraceCsrOffsets() const = 0; + virtual const TraceCsrOffsets& GetTileMeshBus1TtuTraceCsrOffsets() const = 0; + virtual const TraceCsrOffsets& GetTileMeshBus2TtuTraceCsrOffsets() const = 0; + virtual const TraceCsrOffsets& GetTileMeshBus3TtuTraceCsrOffsets() const = 0; + + virtual const TraceCsrOffsets& + GetScalarCoreIrqCompletionBufferTraceCsrOffsets() const { + LOG(FATAL) << "Irq completion buffer trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide0TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 0 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide1TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 1 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide2TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 2 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide3TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 3 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide4TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 4 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide5TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 5 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide6TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 6 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToWide7TraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to wide 7 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow0TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 0 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow1TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 1 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow2TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 2 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow3TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 3 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow4TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 4 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow5TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 5 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaWideToNarrow6TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 6 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& 
GetTileDmaWideToNarrow7TraceCsrOffsets() + const { + LOG(FATAL) << "DMA wide to narrow 7 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileDmaNarrowToNarrowTraceCsrOffsets() + const { + LOG(FATAL) << "DMA narrow to narrow trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp0TraceCsrOffsets() const { + LOG(FATAL) << "Op0 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp1TraceCsrOffsets() const { + LOG(FATAL) << "Op1 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp2TraceCsrOffsets() const { + LOG(FATAL) << "Op2 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp3TraceCsrOffsets() const { + LOG(FATAL) << "Op3 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp4TraceCsrOffsets() const { + LOG(FATAL) << "Op4 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp5TraceCsrOffsets() const { + LOG(FATAL) << "Op5 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp6TraceCsrOffsets() const { + LOG(FATAL) << "Op6 trace not supported."; + unreachable(); + } + virtual const TraceCsrOffsets& GetTileOp7TraceCsrOffsets() const { + LOG(FATAL) << "Op7 trace not supported."; + unreachable(); + } + + // Extracts CSR offsets used to access sync flags in scalar core. + virtual const SyncFlagCsrOffsets& GetScalarCoreAvdataPopSyncFlagCsrOffsets() + const = 0; + virtual const SyncFlagCsrOffsets& + GetScalarCoreParameterPopSyncFlagCsrOffsets() const = 0; + virtual const SyncFlagCsrOffsets& + GetScalarCoreAvdataInfeedSyncFlagCsrOffsets() const = 0; + virtual const SyncFlagCsrOffsets& + GetScalarCoreParameterInfeedSyncFlagCsrOffsets() const = 0; + virtual const SyncFlagCsrOffsets& + GetScalarCoreScalarInfeedSyncFlagCsrOffsets() const = 0; + virtual const SyncFlagCsrOffsets& GetScalarCoreProducerASyncFlagCsrOffsets() + const = 0; + virtual const SyncFlagCsrOffsets& GetScalarCoreProducerBSyncFlagCsrOffsets() + const = 0; + virtual const SyncFlagCsrOffsets& GetScalarCoreRingOutfeedSyncFlagCsrOffsets() + const = 0; + virtual const SyncFlagCsrOffsets& + GetScalarCoreScalarPipelineSyncFlagCsrOffsets() const = 0; + + // Extracts CSR offsets used by bug report generator in DarwiNN. + virtual const DebugHibUserCsrOffsets& GetDebugHibUserCsrOffsets() const = 0; + virtual const DebugHibUserCsrOffsets& + GetContextSpecificDebugHibUserCsrOffsets(int context_id) const = 0; + virtual const DebugScalarCoreCsrOffsets& GetDebugScalarCoreCsrOffsets() + const = 0; + virtual const DebugTileCsrOffsets& GetDebugTileCsrOffsets() const = 0; + + // Beagle-specific. 
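+  // Note that the getters below are not pure virtual: each default
+  // implementation LOG(FATAL)s, so only chip configs that actually expose the
+  // corresponding block (APEX, SCU, CB bridge, USB) need to override them.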
+ virtual const ApexCsrOffsets& GetApexCsrOffsets() const { + LOG(FATAL) << "Apex not supported."; + unreachable(); + } + virtual const ScuCsrOffsets& GetScuCsrOffsets() const { + LOG(FATAL) << "SCU not supported."; + unreachable(); + } + virtual const CbBridgeCsrOffsets& GetCbBridgeCsrOffsets() const { + LOG(FATAL) << "CB bridge not supported."; + unreachable(); + } + virtual const UsbCsrOffsets& GetUsbCsrOffsets() const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& GetUsbFatalErrorInterruptCsrOffsets() + const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& GetUsbTopLevel0InterruptCsrOffsets() + const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& GetUsbTopLevel1InterruptCsrOffsets() + const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& GetUsbTopLevel2InterruptCsrOffsets() + const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } + virtual const InterruptCsrOffsets& GetUsbTopLevel3InterruptCsrOffsets() + const { + LOG(FATAL) << "USB not supported."; + unreachable(); + } +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_CHIP_CONFIG_H_ diff --git a/driver/config/chip_structures.h b/driver/config/chip_structures.h new file mode 100644 index 0000000..5f2a58b --- /dev/null +++ b/driver/config/chip_structures.h @@ -0,0 +1,102 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_CHIP_STRUCTURES_H_ +#define DARWINN_DRIVER_CONFIG_CHIP_STRUCTURES_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +struct ChipStructures { + // Hardware required minimum alignment on buffers. + uint64 minimum_alignment_bytes; + + // Buffer allocation alignment and granularity. Typically this would be same + // as minimum_alignment_bytes above, however may also factor in other + // requirements such as host cache line size, cache API constraints etc. + uint64 allocation_alignment_bytes; + + // Controls AXI burst length. + uint64 axi_dma_burst_limiter; + + // Number of wire interrupts. + uint64 num_wire_interrupts; + + // Number of page table entries. + uint64 num_page_table_entries; + + // Number of physical address bits generated by the hardware. + uint64 physical_address_bits; + + // Addressable byte size of TPU DRAM (if any). This must be divisible by host + // table size. + uint64 tpu_dram_size_bytes; + + // Total size of narrow memory per tile in bytes. + uint64 narrow_memory_capacity; + + // Size of address translation entry for external narrow memory interface. + uint64 external_narrow_memory_translate_entry_size_bytes; + + // Number of X tiles. + uint64 number_x_tiles; + + // Number of Y tiles. + uint64 number_y_tiles; + + // Number of compute threads. 
+ uint64 number_compute_threads; + + // Number of virtual networks. + uint64 number_of_ring_virtual_networks; + + uint64 last_z_out_cell_disable_incompatible_with_sparsity; + + uint64 nlu_buffer_backpressure_causes_assertion; + + // Mesh queue depth. + uint64 mesh_rx_queue_depth; + + // Default VN buffer size. + uint64 default_vn_buffer_memory_lines; + + // Base offset for CSR + uint64 csr_region_base_offset; + + // Size CSR Region + uint64 csr_region_size_bytes; + + // Support trace architectural registers. + uint64 support_trace_arch_registers; + + // Shared Memory's base-and-bound unit size in bytes. This value is needed + // when determining how to program shared memory base and size for parameter, + // instruction, activation, and scalar memory partitions. + uint64 base_and_bound_unit_size_bytes; + + // Number of scalar core contexts supported. Default 1 is legacy behavior with + // no context switching. + uint64 number_of_scalar_core_contexts; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_CHIP_STRUCTURES_H_ diff --git a/driver/config/common_csr_helper.h b/driver/config/common_csr_helper.h new file mode 100644 index 0000000..f070b33 --- /dev/null +++ b/driver/config/common_csr_helper.h @@ -0,0 +1,623 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_COMMON_CSR_HELPER_H_ +#define DARWINN_DRIVER_CONFIG_COMMON_CSR_HELPER_H_ + +#include "driver/bitfield.h" +#include "port/integral_types.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { +namespace registers { + +// CSR helper to access fields for HibError* CSRs. 
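+//
+// Illustrative usage sketch: ReadCsr()/WriteCsr() are hypothetical
+// placeholders for however a given driver accesses the raw 64-bit register,
+// and `offsets` stands for a struct such as DebugHibUserCsrOffsets.
+//
+//   HibError status(ReadCsr(offsets.hib_error_status));
+//   if (status.inbound_page_fault()) {
+//     // ... report the fault ...
+//   }
+//   HibError mask;
+//   mask.set_inbound_page_fault(1);
+//   WriteCsr(offsets.hib_error_mask, mask.raw());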
+class HibError { + public: + HibError() : HibError(0ULL) {} + explicit HibError(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_inbound_page_fault(uint64 value) { + reg_.inbound_page_fault_ = value; + } + uint64 inbound_page_fault() const { return reg_.inbound_page_fault_(); } + void set_extended_page_fault(uint64 value) { + reg_.extended_page_fault_ = value; + } + uint64 extended_page_fault() const { return reg_.extended_page_fault_(); } + void set_csr_parity_error(uint64 value) { reg_.csr_parity_error_ = value; } + uint64 csr_parity_error() const { return reg_.csr_parity_error_(); } + void set_axi_slave_b_error(uint64 value) { reg_.axi_slave_b_error_ = value; } + uint64 axi_slave_b_error() const { return reg_.axi_slave_b_error_(); } + void set_axi_slave_r_error(uint64 value) { reg_.axi_slave_r_error_ = value; } + uint64 axi_slave_r_error() const { return reg_.axi_slave_r_error_(); } + void set_instruction_queue_bad_configuration(uint64 value) { + reg_.instruction_queue_bad_configuration_ = value; + } + uint64 instruction_queue_bad_configuration() const { + return reg_.instruction_queue_bad_configuration_(); + } + void set_input_actv_queue_bad_configuration(uint64 value) { + reg_.input_actv_queue_bad_configuration_ = value; + } + uint64 input_actv_queue_bad_configuration() const { + return reg_.input_actv_queue_bad_configuration_(); + } + void set_param_queue_bad_configuration(uint64 value) { + reg_.param_queue_bad_configuration_ = value; + } + uint64 param_queue_bad_configuration() const { + return reg_.param_queue_bad_configuration_(); + } + void set_output_actv_queue_bad_configuration(uint64 value) { + reg_.output_actv_queue_bad_configuration_ = value; + } + uint64 output_actv_queue_bad_configuration() const { + return reg_.output_actv_queue_bad_configuration_(); + } + void set_instruction_queue_invalid(uint64 value) { + reg_.instruction_queue_invalid_ = value; + } + uint64 instruction_queue_invalid() const { + return reg_.instruction_queue_invalid_(); + } + void set_input_actv_queue_invalid(uint64 value) { + reg_.input_actv_queue_invalid_ = value; + } + uint64 input_actv_queue_invalid() const { + return reg_.input_actv_queue_invalid_(); + } + void set_param_queue_invalid(uint64 value) { + reg_.param_queue_invalid_ = value; + } + uint64 param_queue_invalid() const { return reg_.param_queue_invalid_(); } + void set_output_actv_queue_invalid(uint64 value) { + reg_.output_actv_queue_invalid_ = value; + } + uint64 output_actv_queue_invalid() const { + return reg_.output_actv_queue_invalid_(); + } + void set_length_0_dma(uint64 value) { reg_.length_0_dma_ = value; } + uint64 length_0_dma() const { return reg_.length_0_dma_(); } + void set_virt_table_rdata_uncorr(uint64 value) { + reg_.virt_table_rdata_uncorr_ = value; + } + uint64 virt_table_rdata_uncorr() const { + return reg_.virt_table_rdata_uncorr_(); + } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 1> inbound_page_fault_; + platforms::darwinn::driver::Bitfield<1, 1> extended_page_fault_; + platforms::darwinn::driver::Bitfield<2, 1> csr_parity_error_; + platforms::darwinn::driver::Bitfield<3, 1> axi_slave_b_error_; + platforms::darwinn::driver::Bitfield<4, 1> axi_slave_r_error_; + platforms::darwinn::driver::Bitfield<5, 1> + instruction_queue_bad_configuration_; + platforms::darwinn::driver::Bitfield<6, 1> + input_actv_queue_bad_configuration_; + platforms::darwinn::driver::Bitfield<7, 1> 
param_queue_bad_configuration_; + platforms::darwinn::driver::Bitfield<8, 1> + output_actv_queue_bad_configuration_; + platforms::darwinn::driver::Bitfield<9, 1> instruction_queue_invalid_; + platforms::darwinn::driver::Bitfield<10, 1> input_actv_queue_invalid_; + platforms::darwinn::driver::Bitfield<11, 1> param_queue_invalid_; + platforms::darwinn::driver::Bitfield<12, 1> output_actv_queue_invalid_; + platforms::darwinn::driver::Bitfield<13, 1> length_0_dma_; + platforms::darwinn::driver::Bitfield<14, 1> virt_table_rdata_uncorr_; + } reg_; +}; + +// CSR helper to access fields for *QueueControl CSR. +class QueueControl { + public: + QueueControl() : QueueControl(0ULL) {} + explicit QueueControl(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_enable(uint64 value) { reg_.enable_ = value; } + uint64 enable() const { return reg_.enable_(); } + void set_sc_desc_select(uint64 value) { reg_.sc_desc_select_ = value; } + uint64 sc_desc_select() const { return reg_.sc_desc_select_(); } + void set_sb_wr_enable(uint64 value) { reg_.sb_wr_enable_ = value; } + uint64 sb_wr_enable() const { return reg_.sb_wr_enable_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 1> enable_; + platforms::darwinn::driver::Bitfield<1, 1> sc_desc_select_; + platforms::darwinn::driver::Bitfield<2, 1> sb_wr_enable_; + } reg_; +}; + +// CSR helper to access fields for ScHostIntCount CSR. +class ScHostIntCount { + public: + ScHostIntCount() : ScHostIntCount(0ULL) {} + explicit ScHostIntCount(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_cnt0(uint64 value) { reg_.cnt0_ = value; } + uint64 cnt0() const { return reg_.cnt0_(); } + void set_cnt1(uint64 value) { reg_.cnt1_ = value; } + uint64 cnt1() const { return reg_.cnt1_(); } + void set_cnt2(uint64 value) { reg_.cnt2_ = value; } + uint64 cnt2() const { return reg_.cnt2_(); } + void set_cnt3(uint64 value) { reg_.cnt3_ = value; } + uint64 cnt3() const { return reg_.cnt3_(); } + + // Sets |index|-th field from LSB to |value|. + void set_field(int index, uint64 value) { + switch (index) { + case 0: + reg_.cnt0_ = value; + break; + + case 1: + reg_.cnt1_ = value; + break; + + case 2: + reg_.cnt2_ = value; + break; + + case 3: + reg_.cnt3_ = value; + break; + + default: + CHECK(false) << "Unknown field index: " << index; + break; + } + } + + // Returns |index|-th field from LSB. + uint64 get_field(int index) { + switch (index) { + case 0: + return reg_.cnt0_(); + + case 1: + return reg_.cnt1_(); + + case 2: + return reg_.cnt2_(); + + case 3: + return reg_.cnt3_(); + + default: + LOG(FATAL) << "Unknown field index: " << index; + unreachable(); + } + } + + // Returns masked |value| for |index|-th field from LSB. 
+ uint64 mask_field(int index, uint64 value) { + switch (index) { + case 0: + return value & reg_.cnt0_.mask(); + + case 1: + return value & reg_.cnt1_.mask(); + + case 2: + return value & reg_.cnt2_.mask(); + + case 3: + return value & reg_.cnt3_.mask(); + + default: + LOG(FATAL) << "Unknown field index: " << index; + unreachable(); + } + } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 16> cnt0_; + platforms::darwinn::driver::Bitfield<16, 16> cnt1_; + platforms::darwinn::driver::Bitfield<32, 16> cnt2_; + platforms::darwinn::driver::Bitfield<48, 16> cnt3_; + } reg_; +}; + +// CSR helper to access fields for ScHostIntStatus CSR. +class ScHostIntStatus { + public: + ScHostIntStatus() : ScHostIntStatus(0ULL) {} + explicit ScHostIntStatus(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_hot0(uint64 value) { reg_.hot0_ = value; } + uint64 hot0() const { return reg_.hot0_(); } + void set_hot1(uint64 value) { reg_.hot1_ = value; } + uint64 hot1() const { return reg_.hot1_(); } + void set_hot2(uint64 value) { reg_.hot2_ = value; } + uint64 hot2() const { return reg_.hot2_(); } + void set_hot3(uint64 value) { reg_.hot3_ = value; } + uint64 hot3() const { return reg_.hot3_(); } + + // Sets |index|-th field from LSB to |value|. + void set_field(int index, uint64 value) { + switch (index) { + case 0: + reg_.hot0_ = value; + break; + + case 1: + reg_.hot1_ = value; + break; + + case 2: + reg_.hot2_ = value; + break; + + case 3: + reg_.hot3_ = value; + break; + + default: + CHECK(false) << "Unknown field index: " << index; + break; + } + } + + // Returns |index|-th field from LSB. + uint64 get_field(int index) { + switch (index) { + case 0: + return reg_.hot0_(); + + case 1: + return reg_.hot1_(); + + case 2: + return reg_.hot2_(); + + case 3: + return reg_.hot3_(); + + default: + LOG(FATAL) << "Unknown field index: " << index; + unreachable(); + } + } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 1> hot0_; + platforms::darwinn::driver::Bitfield<1, 1> hot1_; + platforms::darwinn::driver::Bitfield<2, 1> hot2_; + platforms::darwinn::driver::Bitfield<3, 1> hot3_; + } reg_; +}; + +// CSR helper to access fields for ScHostIntVector CSR. +class ScHostIntVector { + public: + ScHostIntVector() : ScHostIntVector(0ULL) {} + explicit ScHostIntVector(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_vector0(uint64 value) { reg_.vector0_ = value; } + uint64 vector0() const { return reg_.vector0_(); } + void set_vector1(uint64 value) { reg_.vector1_ = value; } + uint64 vector1() const { return reg_.vector1_(); } + void set_vector2(uint64 value) { reg_.vector2_ = value; } + uint64 vector2() const { return reg_.vector2_(); } + void set_vector3(uint64 value) { reg_.vector3_ = value; } + uint64 vector3() const { return reg_.vector3_(); } + + // Sets |index|-th field from LSB to |value|. + void set_field(int index, uint64 value) { + switch (index) { + case 0: + reg_.vector0_ = value; + break; + + case 1: + reg_.vector1_ = value; + break; + + case 2: + reg_.vector2_ = value; + break; + + case 3: + reg_.vector3_ = value; + break; + + default: + CHECK(false) << "Unknown field index: " << index; + break; + } + } + + // Returns |index|-th field from LSB. 
+ uint64 get_field(int index) { + switch (index) { + case 0: + return reg_.vector0_(); + + case 1: + return reg_.vector1_(); + + case 2: + return reg_.vector2_(); + + case 3: + return reg_.vector3_(); + + default: + LOG(FATAL) << "Unknown field index: " << index; + unreachable(); + } + } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 7> vector0_; + platforms::darwinn::driver::Bitfield<7, 7> vector1_; + platforms::darwinn::driver::Bitfield<14, 7> vector2_; + platforms::darwinn::driver::Bitfield<21, 7> vector3_; + } reg_; +}; + +// CSR helper to access fields for WireIntPendingBitArray and +// WireIntMaskArray CSR. +class WireIntBitArray { + public: + WireIntBitArray() : WireIntBitArray(0ULL) {} + explicit WireIntBitArray(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + void set_instruction_queue(uint64 value) { reg_.instruction_queue_ = value; } + uint64 instruction_queue() const { return reg_.instruction_queue_(); } + void set_input_actv_queue(uint64 value) { reg_.input_actv_queue_ = value; } + uint64 input_actv_queue() const { return reg_.input_actv_queue_(); } + void set_param_queue(uint64 value) { reg_.param_queue_ = value; } + uint64 param_queue() const { return reg_.param_queue_(); } + void set_output_actv_queue(uint64 value) { reg_.output_actv_queue_ = value; } + uint64 output_actv_queue() const { return reg_.output_actv_queue_(); } + void set_sc_host_0(uint64 value) { reg_.sc_host_0_ = value; } + uint64 sc_host_0() const { return reg_.sc_host_0_(); } + void set_sc_host_1(uint64 value) { reg_.sc_host_1_ = value; } + uint64 sc_host_1() const { return reg_.sc_host_1_(); } + void set_sc_host_2(uint64 value) { reg_.sc_host_2_ = value; } + uint64 sc_host_2() const { return reg_.sc_host_2_(); } + void set_sc_host_3(uint64 value) { reg_.sc_host_3_ = value; } + uint64 sc_host_3() const { return reg_.sc_host_3_(); } + void set_top_level_0(uint64 value) { reg_.top_level_0_ = value; } + uint64 top_level_0() const { return reg_.top_level_0_(); } + void set_top_level_1(uint64 value) { reg_.top_level_1_ = value; } + uint64 top_level_1() const { return reg_.top_level_1_(); } + void set_top_level_2(uint64 value) { reg_.top_level_2_ = value; } + uint64 top_level_2() const { return reg_.top_level_2_(); } + void set_top_level_3(uint64 value) { reg_.top_level_3_ = value; } + uint64 top_level_3() const { return reg_.top_level_3_(); } + void set_fatal_err(uint64 value) { reg_.fatal_err_ = value; } + uint64 fatal_err() const { return reg_.fatal_err_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 1> instruction_queue_; + platforms::darwinn::driver::Bitfield<1, 1> input_actv_queue_; + platforms::darwinn::driver::Bitfield<2, 1> param_queue_; + platforms::darwinn::driver::Bitfield<3, 1> output_actv_queue_; + platforms::darwinn::driver::Bitfield<4, 1> sc_host_0_; + platforms::darwinn::driver::Bitfield<5, 1> sc_host_1_; + platforms::darwinn::driver::Bitfield<6, 1> sc_host_2_; + platforms::darwinn::driver::Bitfield<7, 1> sc_host_3_; + platforms::darwinn::driver::Bitfield<8, 1> top_level_0_; + platforms::darwinn::driver::Bitfield<9, 1> top_level_1_; + platforms::darwinn::driver::Bitfield<10, 1> top_level_2_; + platforms::darwinn::driver::Bitfield<11, 1> top_level_3_; + platforms::darwinn::driver::Bitfield<12, 1> fatal_err_; + } reg_; +}; + +// Interface to access fields for tile configs. 
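+//
+// Illustrative sketch (the tile-id width below is an arbitrary example; the
+// real width is chip-specific):
+//
+//   TileConfig<7> tileconfig;
+//   tileconfig.set_tile(3);      // target a single tile
+//   uint64 single = tileconfig.raw();
+//   tileconfig.set_broadcast();  // all-ones tile id
+//   uint64 all = tileconfig.raw();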
+class TileConfigInterface {
+ public:
+  virtual ~TileConfigInterface() = default;
+
+  // Access to aggregated value.
+  virtual void set_raw(uint64 value) = 0;
+  virtual uint64 raw() const = 0;
+
+  // Sets tile id.
+  virtual void set_broadcast() = 0;
+  virtual void set_tile(uint64 value) = 0;
+
+  // Returns tile field.
+  virtual uint64 tile() const = 0;
+};
+
+// Implements TileConfigInterface with given TILE_BITS.
+template <int TILE_BITS>
+class TileConfig : public TileConfigInterface {
+ public:
+  TileConfig() : TileConfig(0ULL) {}
+  explicit TileConfig(uint64 value) { reg_.raw_ = value; }
+  ~TileConfig() = default;
+
+  void set_raw(uint64 value) override { reg_.raw_ = value; }
+  uint64 raw() const override { return reg_.raw_; }
+
+  void set_broadcast() override {
+    reg_.tile_ = static_cast<uint64>(-1) & reg_.tile_.mask();
+  }
+  void set_tile(uint64 value) override { reg_.tile_ = value; }
+  uint64 tile() const override { return reg_.tile_(); }
+
+ private:
+  union {
+    // Entire CSR value.
+    uint64 raw_;
+    // Tile id field.
+    platforms::darwinn::driver::Bitfield<0, TILE_BITS> tile_;
+  } reg_;
+};
+
+// CSR helper to access fields for clockEnableReg CSR.
+class ClockEnableReg {
+ public:
+  ClockEnableReg() : ClockEnableReg(0ULL) {}
+  explicit ClockEnableReg(uint64 value) { reg_.raw_ = value; }
+
+  void set_raw(uint64 value) { reg_.raw_ = value; }
+  uint64 raw() const { return reg_.raw_; }
+  void set_clock_enable(uint64 value) { reg_.clock_enable_ = value; }
+  uint64 clock_enable() const { return reg_.clock_enable_(); }
+  void set_idle_override(uint64 value) { reg_.idle_override_ = value; }
+
+ private:
+  union {
+    uint64 raw_;
+    platforms::darwinn::driver::Bitfield<0, 1> clock_enable_;
+    platforms::darwinn::driver::Bitfield<1, 1> idle_override_;
+  } reg_;
+};
+
+// CSR helper to access fields for idleRegister CSR.
+class IdleRegister {
+ public:
+  // Defaults to reset value.
+  IdleRegister() : IdleRegister(0x00009000ULL) {}
+  explicit IdleRegister(uint64 value) { reg_.raw_ = value; }
+
+  void set_raw(uint64 value) { reg_.raw_ = value; }
+  uint64 raw() const { return reg_.raw_; }
+
+  void set_enable() { reg_.disable_idle_ = 0; }
+  void set_disable() { reg_.disable_idle_ = 1; }
+  void set_counter(uint64 value) { reg_.counter_ = value; }
+
+ private:
+  union {
+    uint64 raw_;
+    // These are named after fields in the spec.
+    platforms::darwinn::driver::Bitfield<0, 31> counter_;
+    platforms::darwinn::driver::Bitfield<31, 1> disable_idle_;
+  } reg_;
+};
+
+// CSR helper to access fields for logicShutdownPreReg/logicShutdownAllReg.
+template <int NUM_BITS>
+class ShutdownReg {
+ public:
+  // Defaults to reset value.
+  ShutdownReg() : ShutdownReg(0x0) {
+    set_logic_shutdown((1ULL << NUM_BITS) - 1);
+  }
+  explicit ShutdownReg(uint64 value) { reg_.raw_ = value; }
+
+  void set_raw(uint64 value) { reg_.raw_ = value; }
+  uint64 raw() const { return reg_.raw_; }
+
+  void set_logic_shutdown(uint64 value) { reg_.logic_shutdown_ = value; }
+  void set_logic_shutdown_ack(uint64 value) {
+    reg_.logic_shutdown_ack_ = value;
+  }
+  uint64 logic_shutdown() const { return reg_.logic_shutdown_(); }
+  uint64 logic_shutdown_ack() const { return reg_.logic_shutdown_ack_(); }
+
+ private:
+  union {
+    uint64 raw_;
+    // These are named after fields in the spec.
+    platforms::darwinn::driver::Bitfield<0, NUM_BITS> logic_shutdown_;
+    platforms::darwinn::driver::Bitfield<NUM_BITS, NUM_BITS>
+        logic_shutdown_ack_;
+  } reg_;
+};
+
+// CSR helper to access fields for deepSleep.
+class DeepSleep {
+ public:
+  // Defaults to reset value.
+ DeepSleep() : DeepSleep(0x0) {} + explicit DeepSleep(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_to_sleep_delay(uint64 value) { reg_.to_sleep_delay_ = value; } + void set_to_wake_delay(uint64 value) { reg_.to_wake_delay_ = value; } + uint64 narrow_mem_deep_sleep() const { return reg_.narrow_mem_deep_sleep_(); } + uint64 wide_mem_deep_sleep() const { return reg_.wide_mem_deep_sleep_(); } + + private: + union { + uint64 raw_; + // These are named after fields in the spec. + platforms::darwinn::driver::Bitfield<0, 8> to_sleep_delay_; + platforms::darwinn::driver::Bitfield<8, 8> to_wake_delay_; + platforms::darwinn::driver::Bitfield<16, 1> narrow_mem_deep_sleep_; + platforms::darwinn::driver::Bitfield<17, 1> wide_mem_deep_sleep_; + } reg_; +}; + +// Implements field level access for SharedMemoryInitControl CSR. +class SharedMemoryInitControl { + public: + SharedMemoryInitControl() : SharedMemoryInitControl(/*value=*/0ULL) {} + SharedMemoryInitControl(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_trigger(uint64 value) { reg_.trigger_ = value; } + uint64 trigger() const { return reg_.trigger_(); } + + void set_run(uint64 value) { reg_.run_ = value; } + uint64 run() const { return reg_.run_(); } + + void set_done(uint64 value) { reg_.done_ = value; } + uint64 done() const { return reg_.done_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 1> trigger_; + platforms::darwinn::driver::Bitfield<1, 1> run_; + platforms::darwinn::driver::Bitfield<2, 1> done_; + } reg_; +}; + +} // namespace registers +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_COMMON_CSR_HELPER_H_ diff --git a/driver/config/debug_hib_user_csr_offsets.h b/driver/config/debug_hib_user_csr_offsets.h new file mode 100644 index 0000000..b5c62ed --- /dev/null +++ b/driver/config/debug_hib_user_csr_offsets.h @@ -0,0 +1,276 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_DEBUG_HIB_USER_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_DEBUG_HIB_USER_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets that will be dumped as part of the +// driver bug report for user hib. Members are intentionally named to match the +// GCSR register names. 
+struct DebugHibUserCsrOffsets { + uint64 instruction_inbound_queue_total_occupancy; + uint64 instruction_inbound_queue_threshold_counter; + uint64 instruction_inbound_queue_insertion_counter; + uint64 instruction_inbound_queue_full_counter; + uint64 input_actv_inbound_queue_total_occupancy; + uint64 input_actv_inbound_queue_threshold_counter; + uint64 input_actv_inbound_queue_insertion_counter; + uint64 input_actv_inbound_queue_full_counter; + uint64 param_inbound_queue_total_occupancy; + uint64 param_inbound_queue_threshold_counter; + uint64 param_inbound_queue_insertion_counter; + uint64 param_inbound_queue_full_counter; + uint64 output_actv_inbound_queue_total_occupancy; + uint64 output_actv_inbound_queue_threshold_counter; + uint64 output_actv_inbound_queue_insertion_counter; + uint64 output_actv_inbound_queue_full_counter; + uint64 status_block_write_inbound_queue_total_occupancy; + uint64 status_block_write_inbound_queue_threshold_counter; + uint64 status_block_write_inbound_queue_insertion_counter; + uint64 status_block_write_inbound_queue_full_counter; + uint64 queue_fetch_inbound_queue_total_occupancy; + uint64 queue_fetch_inbound_queue_threshold_counter; + uint64 queue_fetch_inbound_queue_insertion_counter; + uint64 queue_fetch_inbound_queue_full_counter; + uint64 instruction_outbound_queue_total_occupancy; + uint64 instruction_outbound_queue_threshold_counter; + uint64 instruction_outbound_queue_insertion_counter; + uint64 instruction_outbound_queue_full_counter; + uint64 input_actv_outbound_queue_total_occupancy; + uint64 input_actv_outbound_queue_threshold_counter; + uint64 input_actv_outbound_queue_insertion_counter; + uint64 input_actv_outbound_queue_full_counter; + uint64 param_outbound_queue_total_occupancy; + uint64 param_outbound_queue_threshold_counter; + uint64 param_outbound_queue_insertion_counter; + uint64 param_outbound_queue_full_counter; + uint64 output_actv_outbound_queue_total_occupancy; + uint64 output_actv_outbound_queue_threshold_counter; + uint64 output_actv_outbound_queue_insertion_counter; + uint64 output_actv_outbound_queue_full_counter; + uint64 status_block_write_outbound_queue_total_occupancy; + uint64 status_block_write_outbound_queue_threshold_counter; + uint64 status_block_write_outbound_queue_insertion_counter; + uint64 status_block_write_outbound_queue_full_counter; + uint64 queue_fetch_outbound_queue_total_occupancy; + uint64 queue_fetch_outbound_queue_threshold_counter; + uint64 queue_fetch_outbound_queue_insertion_counter; + uint64 queue_fetch_outbound_queue_full_counter; + uint64 page_table_request_outbound_queue_total_occupancy; + uint64 page_table_request_outbound_queue_threshold_counter; + uint64 page_table_request_outbound_queue_insertion_counter; + uint64 page_table_request_outbound_queue_full_counter; + uint64 read_tracking_fifo_total_occupancy; + uint64 read_tracking_fifo_threshold_counter; + uint64 read_tracking_fifo_insertion_counter; + uint64 read_tracking_fifo_full_counter; + uint64 write_tracking_fifo_total_occupancy; + uint64 write_tracking_fifo_threshold_counter; + uint64 write_tracking_fifo_insertion_counter; + uint64 write_tracking_fifo_full_counter; + uint64 read_buffer_total_occupancy; + uint64 read_buffer_threshold_counter; + uint64 read_buffer_insertion_counter; + uint64 read_buffer_full_counter; + uint64 axi_aw_credit_shim_total_occupancy; + uint64 axi_aw_credit_shim_threshold_counter; + uint64 axi_aw_credit_shim_insertion_counter; + uint64 axi_aw_credit_shim_full_counter; + uint64 axi_ar_credit_shim_total_occupancy; + 
uint64 axi_ar_credit_shim_threshold_counter; + uint64 axi_ar_credit_shim_insertion_counter; + uint64 axi_ar_credit_shim_full_counter; + uint64 axi_w_credit_shim_total_occupancy; + uint64 axi_w_credit_shim_threshold_counter; + uint64 axi_w_credit_shim_insertion_counter; + uint64 axi_w_credit_shim_full_counter; + uint64 instruction_inbound_queue_empty_cycles_count; + uint64 input_actv_inbound_queue_empty_cycles_count; + uint64 param_inbound_queue_empty_cycles_count; + uint64 output_actv_inbound_queue_empty_cycles_count; + uint64 status_block_write_inbound_queue_empty_cycles_count; + uint64 queue_fetch_inbound_queue_empty_cycles_count; + uint64 instruction_outbound_queue_empty_cycles_count; + uint64 input_actv_outbound_queue_empty_cycles_count; + uint64 param_outbound_queue_empty_cycles_count; + uint64 output_actv_outbound_queue_empty_cycles_count; + uint64 status_block_write_outbound_queue_empty_cycles_count; + uint64 queue_fetch_outbound_queue_empty_cycles_count; + uint64 page_table_request_outbound_queue_empty_cycles_count; + uint64 read_tracking_fifo_empty_cycles_count; + uint64 write_tracking_fifo_empty_cycles_count; + uint64 read_buffer_empty_cycles_count; + uint64 read_request_arbiter_instruction_request_cycles; + uint64 read_request_arbiter_instruction_blocked_cycles; + uint64 read_request_arbiter_instruction_blocked_by_arbitration_cycles; + uint64 read_request_arbiter_instruction_cycles_blocked_over_threshold; + uint64 read_request_arbiter_input_actv_request_cycles; + uint64 read_request_arbiter_input_actv_blocked_cycles; + uint64 read_request_arbiter_input_actv_blocked_by_arbitration_cycles; + uint64 read_request_arbiter_input_actv_cycles_blocked_over_threshold; + uint64 read_request_arbiter_param_request_cycles; + uint64 read_request_arbiter_param_blocked_cycles; + uint64 read_request_arbiter_param_blocked_by_arbitration_cycles; + uint64 read_request_arbiter_param_cycles_blocked_over_threshold; + uint64 read_request_arbiter_queue_fetch_request_cycles; + uint64 read_request_arbiter_queue_fetch_blocked_cycles; + uint64 read_request_arbiter_queue_fetch_blocked_by_arbitration_cycles; + uint64 read_request_arbiter_queue_fetch_cycles_blocked_over_threshold; + uint64 read_request_arbiter_page_table_request_request_cycles; + uint64 read_request_arbiter_page_table_request_blocked_cycles; + uint64 read_request_arbiter_page_table_request_blocked_by_arbitration_cycles; + uint64 read_request_arbiter_page_table_request_cycles_blocked_over_threshold; + uint64 write_request_arbiter_output_actv_request_cycles; + uint64 write_request_arbiter_output_actv_blocked_cycles; + uint64 write_request_arbiter_output_actv_blocked_by_arbitration_cycles; + uint64 write_request_arbiter_output_actv_cycles_blocked_over_threshold; + uint64 write_request_arbiter_status_block_write_request_cycles; + uint64 write_request_arbiter_status_block_write_blocked_cycles; + uint64 write_request_arbiter_status_block_write_blocked_by_arbitration_cycles; + uint64 write_request_arbiter_status_block_write_cycles_blocked_over_threshold; + uint64 address_translation_arbiter_instruction_request_cycles; + uint64 address_translation_arbiter_instruction_blocked_cycles; + uint64 address_translation_arbiter_instruction_blocked_by_arbitration_cycles; + uint64 address_translation_arbiter_instruction_cycles_blocked_over_threshold; + uint64 address_translation_arbiter_input_actv_request_cycles; + uint64 address_translation_arbiter_input_actv_blocked_cycles; + uint64 address_translation_arbiter_input_actv_blocked_by_arbitration_cycles; + uint64 
address_translation_arbiter_input_actv_cycles_blocked_over_threshold; + uint64 address_translation_arbiter_param_request_cycles; + uint64 address_translation_arbiter_param_blocked_cycles; + uint64 address_translation_arbiter_param_blocked_by_arbitration_cycles; + uint64 address_translation_arbiter_param_cycles_blocked_over_threshold; + uint64 address_translation_arbiter_status_block_write_request_cycles; + uint64 address_translation_arbiter_status_block_write_blocked_cycles; + uint64 + address_translation_arbiter_status_block_write_blocked_by_arbitration_cycles; // NOLINT + uint64 + address_translation_arbiter_status_block_write_cycles_blocked_over_threshold; // NOLINT + uint64 address_translation_arbiter_output_actv_request_cycles; + uint64 address_translation_arbiter_output_actv_blocked_cycles; + uint64 address_translation_arbiter_output_actv_blocked_by_arbitration_cycles; + uint64 address_translation_arbiter_output_actv_cycles_blocked_over_threshold; + uint64 address_translation_arbiter_queue_fetch_request_cycles; + uint64 address_translation_arbiter_queue_fetch_blocked_cycles; + uint64 address_translation_arbiter_queue_fetch_blocked_by_arbitration_cycles; + uint64 address_translation_arbiter_queue_fetch_cycles_blocked_over_threshold; + uint64 issued_interrupt_count; + uint64 data_read_16byte_count; + uint64 waiting_for_tag_cycles; + uint64 waiting_for_axi_cycles; + uint64 simple_translations; + uint64 instruction_credits_per_cycle_sum; + uint64 input_actv_credits_per_cycle_sum; + uint64 param_credits_per_cycle_sum; + uint64 output_actv_credits_per_cycle_sum; + uint64 status_block_write_credits_per_cycle_sum; + uint64 queue_fetch_credits_per_cycle_sum; + uint64 page_table_request_credits_per_cycle_sum; + uint64 output_actv_queue_control; + uint64 output_actv_queue_status; + uint64 output_actv_queue_descriptor_size; + uint64 output_actv_queue_minimum_size; + uint64 output_actv_queue_maximum_size; + uint64 output_actv_queue_base; + uint64 output_actv_queue_status_block_base; + uint64 output_actv_queue_size; + uint64 output_actv_queue_tail; + uint64 output_actv_queue_fetched_head; + uint64 output_actv_queue_completed_head; + uint64 output_actv_queue_int_control; + uint64 output_actv_queue_int_status; + uint64 instruction_queue_control; + uint64 instruction_queue_status; + uint64 instruction_queue_descriptor_size; + uint64 instruction_queue_minimum_size; + uint64 instruction_queue_maximum_size; + uint64 instruction_queue_base; + uint64 instruction_queue_status_block_base; + uint64 instruction_queue_size; + uint64 instruction_queue_tail; + uint64 instruction_queue_fetched_head; + uint64 instruction_queue_completed_head; + uint64 instruction_queue_int_control; + uint64 instruction_queue_int_status; + uint64 input_actv_queue_control; + uint64 input_actv_queue_status; + uint64 input_actv_queue_descriptor_size; + uint64 input_actv_queue_minimum_size; + uint64 input_actv_queue_maximum_size; + uint64 input_actv_queue_base; + uint64 input_actv_queue_status_block_base; + uint64 input_actv_queue_size; + uint64 input_actv_queue_tail; + uint64 input_actv_queue_fetched_head; + uint64 input_actv_queue_completed_head; + uint64 input_actv_queue_int_control; + uint64 input_actv_queue_int_status; + uint64 param_queue_control; + uint64 param_queue_status; + uint64 param_queue_descriptor_size; + uint64 param_queue_minimum_size; + uint64 param_queue_maximum_size; + uint64 param_queue_base; + uint64 param_queue_status_block_base; + uint64 param_queue_size; + uint64 param_queue_tail; + uint64 
param_queue_fetched_head; + uint64 param_queue_completed_head; + uint64 param_queue_int_control; + uint64 param_queue_int_status; + uint64 sc_host_int_control; + uint64 sc_host_int_status; + uint64 top_level_int_control; + uint64 top_level_int_status; + uint64 fatal_err_int_control; + uint64 fatal_err_int_status; + uint64 sc_host_int_count; + uint64 dma_pause; + uint64 dma_paused; + uint64 status_block_update; + uint64 hib_error_status; + uint64 hib_error_mask; + uint64 hib_first_error_status; + uint64 hib_first_error_timestamp; + uint64 hib_inject_error; + uint64 read_request_arbiter; + uint64 write_request_arbiter; + uint64 address_translation_arbiter; + uint64 sender_queue_threshold; + uint64 page_fault_address; + uint64 instruction_credits; + uint64 input_actv_credits; + uint64 param_credits; + uint64 output_actv_credits; + uint64 pause_state; + uint64 snapshot; + uint64 idle_assert; + uint64 wire_int_pending_bit_array; + uint64 tileconfig0; + uint64 tileconfig1; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_DEBUG_HIB_USER_CSR_OFFSETS_H_ diff --git a/driver/config/debug_scalar_core_csr_offsets.h b/driver/config/debug_scalar_core_csr_offsets.h new file mode 100644 index 0000000..f8ebf20 --- /dev/null +++ b/driver/config/debug_scalar_core_csr_offsets.h @@ -0,0 +1,437 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_DEBUG_SCALAR_CORE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_DEBUG_SCALAR_CORE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets that will be dumped as part of the +// driver bug report for scalar core. Members are intentionally named to match +// the GCSR register names. 
+struct DebugScalarCoreCsrOffsets { + uint64 topology; + uint64 scMemoryCapacity; + uint64 tileMemoryCapacity; + uint64 scMemoryAccess; + uint64 scMemoryData; + uint64 Timeout; + uint64 Error_ScalarCore; + uint64 Error_Mask_ScalarCore; + uint64 Error_Force_ScalarCore; + uint64 Error_Timestamp_ScalarCore; + uint64 Error_Info_ScalarCore; + + // Contextual CSRs in Scalar Datapaths + uint64 scalarCoreRunControl; + uint64 scalarCoreBreakPoint; + uint64 currentPc; + uint64 executeControl; + // Context 0 + uint64 scalarDatapath_0RunControl; + uint64 scalarDatapath_0BreakPoint; + uint64 currentPc_0; + uint64 executeControl_0; + // Context 1 + uint64 scalarDatapath_1RunControl; + uint64 scalarDatapath_1BreakPoint; + uint64 currentPc_1; + uint64 executeControl_1; + // Context 2 + uint64 scalarDatapath_2RunControl; + uint64 scalarDatapath_2BreakPoint; + uint64 currentPc_2; + uint64 executeControl_2; + // Context 3 + uint64 scalarDatapath_3RunControl; + uint64 scalarDatapath_3BreakPoint; + uint64 currentPc_3; + uint64 executeControl_3; + // Sync Flag CSRs + uint64 SyncCounter_AVDATA_POP; + uint64 SyncCounter_PARAMETER_POP; + uint64 SyncCounter_AVDATA_INFEED; + uint64 SyncCounter_PARAMETER_INFEED; + uint64 SyncCounter_SCALAR_INFEED; + uint64 SyncCounter_PRODUCER_A; + uint64 SyncCounter_PRODUCER_B; + uint64 SyncCounter_RING_OUTFEED; + // Context 0 + uint64 SyncCounter_AVDATA_POP_0_0; + uint64 SyncCounter_PARAMETER_POP_0_0; + uint64 SyncCounter_AVDATA_INFEED_0_0; + uint64 SyncCounter_PARAMETER_INFEED_0_0; + uint64 SyncCounter_SCALAR_INFEED_0_0; + uint64 SyncCounter_PRODUCER_A_0_0; + uint64 SyncCounter_PRODUCER_B_0_0; + uint64 SyncCounter_RING_OUTFEED_0_0; + // Context 1 + uint64 SyncCounter_AVDATA_POP_1_0; + uint64 SyncCounter_PARAMETER_POP_1_0; + uint64 SyncCounter_AVDATA_INFEED_1_0; + uint64 SyncCounter_PARAMETER_INFEED_1_0; + uint64 SyncCounter_SCALAR_INFEED_1_0; + uint64 SyncCounter_PRODUCER_A_1_0; + uint64 SyncCounter_PRODUCER_B_1_0; + uint64 SyncCounter_RING_OUTFEED_1_0; + // Context 2 + uint64 SyncCounter_AVDATA_POP_2_0; + uint64 SyncCounter_PARAMETER_POP_2_0; + uint64 SyncCounter_AVDATA_INFEED_2_0; + uint64 SyncCounter_PARAMETER_INFEED_2_0; + uint64 SyncCounter_SCALAR_INFEED_2_0; + uint64 SyncCounter_PRODUCER_A_2_0; + uint64 SyncCounter_PRODUCER_B_2_0; + uint64 SyncCounter_RING_OUTFEED_2_0; + // Context 3 + uint64 SyncCounter_AVDATA_POP_3_0; + uint64 SyncCounter_PARAMETER_POP_3_0; + uint64 SyncCounter_AVDATA_INFEED_3_0; + uint64 SyncCounter_PARAMETER_INFEED_3_0; + uint64 SyncCounter_SCALAR_INFEED_3_0; + uint64 SyncCounter_PRODUCER_A_3_0; + uint64 SyncCounter_PRODUCER_B_3_0; + uint64 SyncCounter_RING_OUTFEED_3_0; + + // Pop Input Control Units + uint64 avDataPopRunControl; + uint64 avDataPopBreakPoint; + uint64 avDataPopRunStatus; + uint64 avDataPopOverwriteMode; + uint64 avDataPopEnableTracing; + uint64 avDataPopStartCycle; + uint64 avDataPopEndCycle; + uint64 avDataPopStallCycleCount; + uint64 avDataPopProgramCounter; + uint64 avDataPopTtuStateRegFile; + uint64 avDataPopTrace; + // Context 0 + uint64 avDataPop_0RunControl; + uint64 avDataPop_0BreakPoint; + uint64 avDataPop_0RunStatus; + uint64 avDataPop_0OverwriteMode; + uint64 avDataPop_0EnableTracing; + uint64 avDataPop_0StartCycle; + uint64 avDataPop_0EndCycle; + uint64 avDataPop_0StallCycleCount; + uint64 avDataPop_0ProgramCounter; + uint64 avDataPop_0TtuStateRegFile; + uint64 avDataPop_0Trace; + // Context 1 + uint64 avDataPop_1RunControl; + uint64 avDataPop_1BreakPoint; + uint64 avDataPop_1RunStatus; + uint64 avDataPop_1OverwriteMode; + 
uint64 avDataPop_1EnableTracing; + uint64 avDataPop_1StartCycle; + uint64 avDataPop_1EndCycle; + uint64 avDataPop_1StallCycleCount; + uint64 avDataPop_1ProgramCounter; + uint64 avDataPop_1TtuStateRegFile; + uint64 avDataPop_1Trace; + // Context 2 + uint64 avDataPop_2RunControl; + uint64 avDataPop_2BreakPoint; + uint64 avDataPop_2RunStatus; + uint64 avDataPop_2OverwriteMode; + uint64 avDataPop_2EnableTracing; + uint64 avDataPop_2StartCycle; + uint64 avDataPop_2EndCycle; + uint64 avDataPop_2StallCycleCount; + uint64 avDataPop_2ProgramCounter; + uint64 avDataPop_2TtuStateRegFile; + uint64 avDataPop_2Trace; + // Context 3 + uint64 avDataPop_3RunControl; + uint64 avDataPop_3BreakPoint; + uint64 avDataPop_3RunStatus; + uint64 avDataPop_3OverwriteMode; + uint64 avDataPop_3EnableTracing; + uint64 avDataPop_3StartCycle; + uint64 avDataPop_3EndCycle; + uint64 avDataPop_3StallCycleCount; + uint64 avDataPop_3ProgramCounter; + uint64 avDataPop_3TtuStateRegFile; + uint64 avDataPop_3Trace; + + uint64 parameterPopRunControl; + uint64 parameterPopBreakPoint; + uint64 parameterPopRunStatus; + uint64 parameterPopOverwriteMode; + uint64 parameterPopEnableTracing; + uint64 parameterPopStartCycle; + uint64 parameterPopEndCycle; + uint64 parameterPopStallCycleCount; + uint64 parameterPopProgramCounter; + uint64 parameterPopTtuStateRegFile; + uint64 parameterPopTrace; + // Context 0 + uint64 parameterPop_0RunControl; + uint64 parameterPop_0BreakPoint; + uint64 parameterPop_0RunStatus; + uint64 parameterPop_0OverwriteMode; + uint64 parameterPop_0EnableTracing; + uint64 parameterPop_0StartCycle; + uint64 parameterPop_0EndCycle; + uint64 parameterPop_0StallCycleCount; + uint64 parameterPop_0ProgramCounter; + uint64 parameterPop_0TtuStateRegFile; + uint64 parameterPop_0Trace; + // Context 1 + uint64 parameterPop_1RunControl; + uint64 parameterPop_1BreakPoint; + uint64 parameterPop_1RunStatus; + uint64 parameterPop_1OverwriteMode; + uint64 parameterPop_1EnableTracing; + uint64 parameterPop_1StartCycle; + uint64 parameterPop_1EndCycle; + uint64 parameterPop_1StallCycleCount; + uint64 parameterPop_1ProgramCounter; + uint64 parameterPop_1TtuStateRegFile; + uint64 parameterPop_1Trace; + // Context 2 + uint64 parameterPop_2RunControl; + uint64 parameterPop_2BreakPoint; + uint64 parameterPop_2RunStatus; + uint64 parameterPop_2OverwriteMode; + uint64 parameterPop_2EnableTracing; + uint64 parameterPop_2StartCycle; + uint64 parameterPop_2EndCycle; + uint64 parameterPop_2StallCycleCount; + uint64 parameterPop_2ProgramCounter; + uint64 parameterPop_2TtuStateRegFile; + uint64 parameterPop_2Trace; + // Context 3 + uint64 parameterPop_3RunControl; + uint64 parameterPop_3BreakPoint; + uint64 parameterPop_3RunStatus; + uint64 parameterPop_3OverwriteMode; + uint64 parameterPop_3EnableTracing; + uint64 parameterPop_3StartCycle; + uint64 parameterPop_3EndCycle; + uint64 parameterPop_3StallCycleCount; + uint64 parameterPop_3ProgramCounter; + uint64 parameterPop_3TtuStateRegFile; + uint64 parameterPop_3Trace; + // Infeed Control Units + uint64 infeedRunControl; + uint64 infeedRunStatus; + uint64 infeedBreakPoint; + uint64 infeedOverwriteMode; + uint64 infeedEnableTracing; + uint64 infeedStartCycle; + uint64 infeedEndCycle; + uint64 infeedStallCycleCount; + uint64 infeedProgramCounter; + uint64 infeedTtuStateRegFile; + // Context 0 + uint64 infeed_0_0RunControl; + uint64 infeed_0_0RunStatus; + uint64 infeed_0_0BreakPoint; + uint64 infeed_0_0OverwriteMode; + uint64 infeed_0_0EnableTracing; + uint64 infeed_0_0StartCycle; + uint64 
infeed_0_0EndCycle; + uint64 infeed_0_0StallCycleCount; + uint64 infeed_0_0ProgramCounter; + uint64 infeed_0_0TtuStateRegFile; + uint64 infeed_0_1RunControl; + uint64 infeed_0_1RunStatus; + uint64 infeed_0_1BreakPoint; + uint64 infeed_0_1OverwriteMode; + uint64 infeed_0_1EnableTracing; + uint64 infeed_0_1StartCycle; + uint64 infeed_0_1EndCycle; + uint64 infeed_0_1StallCycleCount; + uint64 infeed_0_1ProgramCounter; + uint64 infeed_0_1TtuStateRegFile; + // Context 1 + uint64 infeed_1_0RunControl; + uint64 infeed_1_0RunStatus; + uint64 infeed_1_0BreakPoint; + uint64 infeed_1_0OverwriteMode; + uint64 infeed_1_0EnableTracing; + uint64 infeed_1_0StartCycle; + uint64 infeed_1_0EndCycle; + uint64 infeed_1_0StallCycleCount; + uint64 infeed_1_0ProgramCounter; + uint64 infeed_1_0TtuStateRegFile; + uint64 infeed_1_1RunControl; + uint64 infeed_1_1RunStatus; + uint64 infeed_1_1BreakPoint; + uint64 infeed_1_1OverwriteMode; + uint64 infeed_1_1EnableTracing; + uint64 infeed_1_1StartCycle; + uint64 infeed_1_1EndCycle; + uint64 infeed_1_1StallCycleCount; + uint64 infeed_1_1ProgramCounter; + uint64 infeed_1_1TtuStateRegFile; + // Context 2 + uint64 infeed_2_0RunControl; + uint64 infeed_2_0RunStatus; + uint64 infeed_2_0BreakPoint; + uint64 infeed_2_0OverwriteMode; + uint64 infeed_2_0EnableTracing; + uint64 infeed_2_0StartCycle; + uint64 infeed_2_0EndCycle; + uint64 infeed_2_0StallCycleCount; + uint64 infeed_2_0ProgramCounter; + uint64 infeed_2_0TtuStateRegFile; + uint64 infeed_2_1RunControl; + uint64 infeed_2_1RunStatus; + uint64 infeed_2_1BreakPoint; + uint64 infeed_2_1OverwriteMode; + uint64 infeed_2_1EnableTracing; + uint64 infeed_2_1StartCycle; + uint64 infeed_2_1EndCycle; + uint64 infeed_2_1StallCycleCount; + uint64 infeed_2_1ProgramCounter; + uint64 infeed_2_1TtuStateRegFile; + // Context 3 + uint64 infeed_3_0RunControl; + uint64 infeed_3_0RunStatus; + uint64 infeed_3_0BreakPoint; + uint64 infeed_3_0OverwriteMode; + uint64 infeed_3_0EnableTracing; + uint64 infeed_3_0StartCycle; + uint64 infeed_3_0EndCycle; + uint64 infeed_3_0StallCycleCount; + uint64 infeed_3_0ProgramCounter; + uint64 infeed_3_0TtuStateRegFile; + uint64 infeed_3_1RunControl; + uint64 infeed_3_1RunStatus; + uint64 infeed_3_1BreakPoint; + uint64 infeed_3_1OverwriteMode; + uint64 infeed_3_1EnableTracing; + uint64 infeed_3_1StartCycle; + uint64 infeed_3_1EndCycle; + uint64 infeed_3_1StallCycleCount; + uint64 infeed_3_1ProgramCounter; + uint64 infeed_3_1TtuStateRegFile; + + // Outfeed Control Units + uint64 outfeedRunControl; + uint64 outfeedRunStatus; + uint64 outfeedBreakPoint; + uint64 outfeedOverwriteMode; + uint64 outfeedEnableTracing; + uint64 outfeedStartCycle; + uint64 outfeedEndCycle; + uint64 outfeedStallCycleCount; + uint64 outfeedProgramCounter; + uint64 outfeedTtuStateRegFile; + // Context 0 + uint64 outfeed_0_0RunControl; + uint64 outfeed_0_0RunStatus; + uint64 outfeed_0_0BreakPoint; + uint64 outfeed_0_0OverwriteMode; + uint64 outfeed_0_0EnableTracing; + uint64 outfeed_0_0StartCycle; + uint64 outfeed_0_0EndCycle; + uint64 outfeed_0_0StallCycleCount; + uint64 outfeed_0_0ProgramCounter; + uint64 outfeed_0_0TtuStateRegFile; + uint64 outfeed_0_1RunControl; + uint64 outfeed_0_1RunStatus; + uint64 outfeed_0_1BreakPoint; + uint64 outfeed_0_1OverwriteMode; + uint64 outfeed_0_1EnableTracing; + uint64 outfeed_0_1StartCycle; + uint64 outfeed_0_1EndCycle; + uint64 outfeed_0_1StallCycleCount; + uint64 outfeed_0_1ProgramCounter; + uint64 outfeed_0_1TtuStateRegFile; + // Context 1 + uint64 outfeed_1_0RunControl; + uint64 outfeed_1_0RunStatus; 
+ uint64 outfeed_1_0BreakPoint; + uint64 outfeed_1_0OverwriteMode; + uint64 outfeed_1_0EnableTracing; + uint64 outfeed_1_0StartCycle; + uint64 outfeed_1_0EndCycle; + uint64 outfeed_1_0StallCycleCount; + uint64 outfeed_1_0ProgramCounter; + uint64 outfeed_1_0TtuStateRegFile; + uint64 outfeed_1_1RunControl; + uint64 outfeed_1_1RunStatus; + uint64 outfeed_1_1BreakPoint; + uint64 outfeed_1_1OverwriteMode; + uint64 outfeed_1_1EnableTracing; + uint64 outfeed_1_1StartCycle; + uint64 outfeed_1_1EndCycle; + uint64 outfeed_1_1StallCycleCount; + uint64 outfeed_1_1ProgramCounter; + uint64 outfeed_1_1TtuStateRegFile; + // Context 2 + uint64 outfeed_2_0RunControl; + uint64 outfeed_2_0RunStatus; + uint64 outfeed_2_0BreakPoint; + uint64 outfeed_2_0OverwriteMode; + uint64 outfeed_2_0EnableTracing; + uint64 outfeed_2_0StartCycle; + uint64 outfeed_2_0EndCycle; + uint64 outfeed_2_0StallCycleCount; + uint64 outfeed_2_0ProgramCounter; + uint64 outfeed_2_0TtuStateRegFile; + uint64 outfeed_2_1RunControl; + uint64 outfeed_2_1RunStatus; + uint64 outfeed_2_1BreakPoint; + uint64 outfeed_2_1OverwriteMode; + uint64 outfeed_2_1EnableTracing; + uint64 outfeed_2_1StartCycle; + uint64 outfeed_2_1EndCycle; + uint64 outfeed_2_1StallCycleCount; + uint64 outfeed_2_1ProgramCounter; + uint64 outfeed_2_1TtuStateRegFile; + // Context 3 + uint64 outfeed_3_0RunControl; + uint64 outfeed_3_0RunStatus; + uint64 outfeed_3_0BreakPoint; + uint64 outfeed_3_0OverwriteMode; + uint64 outfeed_3_0EnableTracing; + uint64 outfeed_3_0StartCycle; + uint64 outfeed_3_0EndCycle; + uint64 outfeed_3_0StallCycleCount; + uint64 outfeed_3_0ProgramCounter; + uint64 outfeed_3_0TtuStateRegFile; + uint64 outfeed_3_1RunControl; + uint64 outfeed_3_1RunStatus; + uint64 outfeed_3_1BreakPoint; + uint64 outfeed_3_1OverwriteMode; + uint64 outfeed_3_1EnableTracing; + uint64 outfeed_3_1StartCycle; + uint64 outfeed_3_1EndCycle; + uint64 outfeed_3_1StallCycleCount; + uint64 outfeed_3_1ProgramCounter; + uint64 outfeed_3_1TtuStateRegFile; + + // Scalar Pipeline + uint64 scalarCoreRunStatus; + uint64 scalarCoreRunStatus_0; + uint64 scalarCoreRunStatus_1; + uint64 scalarCoreRunStatus_2; + uint64 scalarCoreRunStatus_3; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_DEBUG_SCALAR_CORE_CSR_OFFSETS_H_ diff --git a/driver/config/debug_tile_csr_offsets.h b/driver/config/debug_tile_csr_offsets.h new file mode 100644 index 0000000..ad95c68 --- /dev/null +++ b/driver/config/debug_tile_csr_offsets.h @@ -0,0 +1,185 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_DEBUG_TILE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_DEBUG_TILE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets that will be dumped as part of the +// driver bug report for tiles. 
Members are intentionally named to match +// the GCSR register names. +struct DebugTileCsrOffsets { + uint64 TileClockControl; + uint64 tileid; + uint64 scratchpad; + uint64 memoryAccess; + uint64 memoryData; + uint64 narrowMemoryContext_0; + uint64 narrowMemoryContext_1; + uint64 narrowMemoryContext_2; + uint64 narrowMemoryContext_3; + uint64 deepSleep; + uint64 SyncCounter_AVDATA; + uint64 SyncCounter_PARAMETERS; + uint64 SyncCounter_PARTIAL_SUMS; + uint64 SyncCounter_MESH_NORTH_IN; + uint64 SyncCounter_MESH_EAST_IN; + uint64 SyncCounter_MESH_SOUTH_IN; + uint64 SyncCounter_MESH_WEST_IN; + uint64 SyncCounter_MESH_NORTH_OUT; + uint64 SyncCounter_MESH_EAST_OUT; + uint64 SyncCounter_MESH_SOUTH_OUT; + uint64 SyncCounter_MESH_WEST_OUT; + uint64 SyncCounter_WIDE_TO_NARROW; + uint64 SyncCounter_WIDE_TO_SCALING; + uint64 SyncCounter_NARROW_TO_WIDE; + uint64 SyncCounter_RING_READ_A; + uint64 SyncCounter_RING_READ_B; + uint64 SyncCounter_RING_WRITE; + uint64 SyncCounter_RING_PRODUCER_A; + uint64 SyncCounter_RING_PRODUCER_B; + uint64 opRunControl; + uint64 PowerSaveData; + uint64 opBreakPoint; + uint64 StallCounter; + uint64 opRunStatus; + uint64 OpOverwriteMode; + uint64 OpEnableTracing; + uint64 OpStartCycle; + uint64 OpEndCycle; + uint64 OpStallCycleCount; + uint64 OpProgramCounter; + uint64 wideToNarrowRunControl; + uint64 wideToNarrowRunStatus; + uint64 wideToNarrowBreakPoint; + uint64 dmaWideToNarrowOverwriteMode; + uint64 dmaWideToNarrowEnableTracing; + uint64 dmaWideToNarrowStartCycle; + uint64 dmaWideToNarrowEndCycle; + uint64 dmaWideToNarrowStallCycleCount; + uint64 dmaWideToNarrowProgramCounter; + uint64 narrowToWideRunControl; + uint64 narrowToWideRunStatus; + uint64 narrowToWideBreakPoint; + uint64 dmaNarrowToWideOverwriteMode; + uint64 dmaNarrowToWideEnableTracing; + uint64 dmaNarrowToWideStartCycle; + uint64 dmaNarrowToWideEndCycle; + uint64 dmaNarrowToWideStallCycleCount; + uint64 dmaNarrowToWideProgramCounter; + uint64 ringBusConsumer0RunControl; + uint64 ringBusConsumer0RunStatus; + uint64 ringBusConsumer0BreakPoint; + uint64 dmaRingBusConsumer0OverwriteMode; + uint64 dmaRingBusConsumer0EnableTracing; + uint64 dmaRingBusConsumer0StartCycle; + uint64 dmaRingBusConsumer0EndCycle; + uint64 dmaRingBusConsumer0StallCycleCount; + uint64 dmaRingBusConsumer0ProgramCounter; + uint64 ringBusConsumer1RunControl; + uint64 ringBusConsumer1RunStatus; + uint64 ringBusConsumer1BreakPoint; + uint64 dmaRingBusConsumer1OverwriteMode; + uint64 dmaRingBusConsumer1EnableTracing; + uint64 dmaRingBusConsumer1StartCycle; + uint64 dmaRingBusConsumer1EndCycle; + uint64 dmaRingBusConsumer1StallCycleCount; + uint64 dmaRingBusConsumer1ProgramCounter; + uint64 ringBusProducerRunControl; + uint64 ringBusProducerRunStatus; + uint64 ringBusProducerBreakPoint; + uint64 dmaRingBusProducerOverwriteMode; + uint64 dmaRingBusProducerEnableTracing; + uint64 dmaRingBusProducerStartCycle; + uint64 dmaRingBusProducerEndCycle; + uint64 dmaRingBusProducerStallCycleCount; + uint64 dmaRingBusProducerProgramCounter; + uint64 meshBus0RunControl; + uint64 meshBus0RunStatus; + uint64 meshBus0BreakPoint; + uint64 dmaMeshBus0OverwriteMode; + uint64 dmaMeshBus0EnableTracing; + uint64 dmaMeshBus0StartCycle; + uint64 dmaMeshBus0EndCycle; + uint64 dmaMeshBus0StallCycleCount; + uint64 dmaMeshBus0ProgramCounter; + uint64 meshBus1RunControl; + uint64 meshBus1RunStatus; + uint64 meshBus1BreakPoint; + uint64 dmaMeshBus1OverwriteMode; + uint64 dmaMeshBus1EnableTracing; + uint64 dmaMeshBus1StartCycle; + uint64 dmaMeshBus1EndCycle; + uint64 
dmaMeshBus1StallCycleCount; + uint64 dmaMeshBus1ProgramCounter; + uint64 meshBus2RunControl; + uint64 meshBus2RunStatus; + uint64 meshBus2BreakPoint; + uint64 dmaMeshBus2OverwriteMode; + uint64 dmaMeshBus2EnableTracing; + uint64 dmaMeshBus2StartCycle; + uint64 dmaMeshBus2EndCycle; + uint64 dmaMeshBus2StallCycleCount; + uint64 dmaMeshBus2ProgramCounter; + uint64 meshBus3RunControl; + uint64 meshBus3RunStatus; + uint64 meshBus3BreakPoint; + uint64 dmaMeshBus3OverwriteMode; + uint64 dmaMeshBus3EnableTracing; + uint64 dmaMeshBus3StartCycle; + uint64 dmaMeshBus3EndCycle; + uint64 dmaMeshBus3StallCycleCount; + uint64 dmaMeshBus3ProgramCounter; + uint64 Error_Tile; + uint64 Error_Mask_Tile; + uint64 Error_Force_Tile; + uint64 Error_Timestamp_Tile; + uint64 Error_Info_Tile; + uint64 Timeout; + uint64 opTtuStateRegFile; + uint64 OpTrace; + uint64 wideToNarrowTtuStateRegFile; + uint64 dmaWideToNarrowTrace; + uint64 narrowToWideTtuStateRegFile; + uint64 dmaNarrowToWideTrace; + uint64 ringBusConsumer0TtuStateRegFile; + uint64 dmaRingBusConsumer0Trace; + uint64 ringBusConsumer1TtuStateRegFile; + uint64 dmaRingBusConsumer1Trace; + uint64 ringBusProducerTtuStateRegFile; + uint64 dmaRingBusProducerTrace; + uint64 meshBus0TtuStateRegFile; + uint64 dmaMeshBus0Trace; + uint64 meshBus1TtuStateRegFile; + uint64 dmaMeshBus1Trace; + uint64 meshBus2TtuStateRegFile; + uint64 dmaMeshBus2Trace; + uint64 meshBus3TtuStateRegFile; + uint64 dmaMeshBus3Trace; + uint64 narrowMemoryIsolation; + uint64 narrowMemoryRetention; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_DEBUG_TILE_CSR_OFFSETS_H_ diff --git a/driver/config/hib_kernel_csr_offsets.h b/driver/config/hib_kernel_csr_offsets.h new file mode 100644 index 0000000..d793264 --- /dev/null +++ b/driver/config/hib_kernel_csr_offsets.h @@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_HIB_KERNEL_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_HIB_KERNEL_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for kernel space in HIB. Members are +// intentionally named to match the GCSR register names. +struct HibKernelCsrOffsets { + uint64 page_table_size; + uint64 extended_table; + uint64 dma_pause; + + // Tracks whether initialization is done. + uint64 page_table_init; + uint64 msix_table_init; + + // Points to the first entry in the page table. Subsequent entries can be + // accessed with increasing offsets if they exist. + uint64 page_table; + + // Limits AXI DMA burst. 
+ uint64 dma_burst_limiter; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_HIB_KERNEL_CSR_OFFSETS_H_ diff --git a/driver/config/hib_user_csr_offsets.h b/driver/config/hib_user_csr_offsets.h new file mode 100644 index 0000000..9ff44b0 --- /dev/null +++ b/driver/config/hib_user_csr_offsets.h @@ -0,0 +1,58 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_HIB_USER_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_HIB_USER_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for user space in HIB. Members are +// intentionally named to match the GCSR register names. +struct HibUserCsrOffsets { + // Interrupt control and status for top level. + uint64 top_level_int_control; + uint64 top_level_int_status; + + // Interrupt count for scalar core. + uint64 sc_host_int_count; + + // DMA pauses. + uint64 dma_pause; + uint64 dma_paused; + + // Enable/disable status block update. + uint64 status_block_update; + + // HIB errors. + uint64 hib_error_status; + uint64 hib_error_mask; + uint64 hib_first_error_status; + uint64 hib_first_error_timestamp; + uint64 hib_inject_error; + + // Limits AXI DMA burst. + uint64 dma_burst_limiter; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_HIB_USER_CSR_OFFSETS_H_ diff --git a/driver/config/interrupt_csr_offsets.h b/driver/config/interrupt_csr_offsets.h new file mode 100644 index 0000000..e2445d8 --- /dev/null +++ b/driver/config/interrupt_csr_offsets.h @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_INTERRUPT_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_INTERRUPT_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for enabling/disabling/clearing +// interrupts. Members are intentionally named to match the GCSR register +// names. +struct InterruptCsrOffsets { + // Interrupt control and status CSRs. 
+ uint64 control; + uint64 status; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_INTERRUPT_CSR_OFFSETS_H_ diff --git a/driver/config/memory_csr_offsets.h b/driver/config/memory_csr_offsets.h new file mode 100644 index 0000000..ca034cf --- /dev/null +++ b/driver/config/memory_csr_offsets.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_MEMORY_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_MEMORY_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for configuring memory accesses for +// scalar core and tiles. +// Members are intentionally named to match the GCSR register names. +struct MemoryCsrOffsets { + uint64 Access; + uint64 Data; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_MEMORY_CSR_OFFSETS_H_ diff --git a/driver/config/misc_csr_offsets.h b/driver/config/misc_csr_offsets.h new file mode 100644 index 0000000..4576177 --- /dev/null +++ b/driver/config/misc_csr_offsets.h @@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_MISC_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_MISC_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for custom block misc. +// Members are intentionally named to match the GCSR register names. +struct MiscCsrOffsets { + uint64 idleRegister; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_MISC_CSR_OFFSETS_H_ diff --git a/driver/config/msix_csr_offsets.h b/driver/config/msix_csr_offsets.h new file mode 100644 index 0000000..b872eeb --- /dev/null +++ b/driver/config/msix_csr_offsets.h @@ -0,0 +1,61 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_MSIX_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_MSIX_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for programming MSIX interrupts. +// Members are intentionally named to match the GCSR register names. +struct MsixCsrOffsets { + // Interrupt vectors for descriptor queues. + uint64 instruction_queue_int_vector; + uint64 input_actv_queue_int_vector; + uint64 param_queue_int_vector; + uint64 output_actv_queue_int_vector; + + // top_level_int_vector contains indices for four interrupts. + // [27:21] | [20:14] | [13:7] | [6:0] + // thermal | mbist | PCIe | thermalShutdown + uint64 top_level_int_vector; + + // sc_host_int_vector contains indices for four interrupts. + // [27:21] | [20:14] | [13:7] | [6:0] + // INT_3 | INT_2 | INT_1 | INT_0 + uint64 sc_host_int_vector; + + // HIB fatal error. + uint64 fatal_err_int_vector; + + // Points to the first bit array. Subsequent bit arrays can be accessed with + // increasing offsets if they exist. + uint64 msix_pending_bit_array0; + + // Points to the first entry in the table. Subsequent entries can be accessed + // with increasing offsets if they exist. + uint64 msix_table; +}; + +}  // namespace config +}  // namespace driver +}  // namespace darwinn +}  // namespace platforms + +#endif  // DARWINN_DRIVER_CONFIG_MSIX_CSR_OFFSETS_H_ diff --git a/driver/config/power_throttle_csr_helper.h b/driver/config/power_throttle_csr_helper.h new file mode 100644 index 0000000..31d251e --- /dev/null +++ b/driver/config/power_throttle_csr_helper.h @@ -0,0 +1,220 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_POWER_THROTTLE_CSR_HELPER_H_ +#define DARWINN_DRIVER_CONFIG_POWER_THROTTLE_CSR_HELPER_H_ + +#include "driver/bitfield.h" +#include "port/integral_types.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { +namespace registers { + +// Implements field level access for the EnergyTable CSR register.
+class EnergyTable { + public: + EnergyTable() : EnergyTable(/*value=*/0ULL) {} + EnergyTable(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_idle_power(uint64 value) { reg_.idle_power_ = value; } + uint64 idle_power() const { return reg_.idle_power_(); } + + void set_ring_bus(uint64 value) { reg_.ring_bus_ = value; } + uint64 ring_bus() const { return reg_.ring_bus_(); } + + void set_nlu_active(uint64 value) { reg_.nlu_active_ = value; } + uint64 nlu_active() const { return reg_.nlu_active_(); } + + void set_wide_byte_access(uint64 value) { reg_.wide_byte_access_ = value; } + uint64 wide_byte_access() const { return reg_.wide_byte_access_(); } + + void set_narrow_mem_word_access(uint64 value) { + reg_.narrow_mem_word_access_ = value; + } + uint64 narrow_mem_word_access() const { + return reg_.narrow_mem_word_access_(); + } + + void set_int_multiply(uint64 value) { reg_.int_multiply_ = value; } + uint64 int_multiply() const { return reg_.int_multiply_(); } + + void set_int_adder(uint64 value) { reg_.int_adder_ = value; } + uint64 int_adder() const { return reg_.int_adder_(); } + + void set_float_32_adder(uint64 value) { reg_.float_32_adder_ = value; } + uint64 float_32_adder() const { return reg_.float_32_adder_(); } + + void set_input_bus_transfer(uint64 value) { + reg_.input_bus_transfer_ = value; + } + uint64 input_bus_transfer() const { return reg_.input_bus_transfer_(); } + + void set_wide_data_transfer(uint64 value) { + reg_.wide_data_transfer_ = value; + } + uint64 wide_data_transfer() const { return reg_.wide_data_transfer_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 2> idle_power_; + platforms::darwinn::driver::Bitfield<2, 2> ring_bus_; + platforms::darwinn::driver::Bitfield<4, 2> nlu_active_; + platforms::darwinn::driver::Bitfield<6, 2> wide_byte_access_; + platforms::darwinn::driver::Bitfield<8, 2> narrow_mem_word_access_; + platforms::darwinn::driver::Bitfield<10, 2> int_multiply_; + platforms::darwinn::driver::Bitfield<12, 2> int_adder_; + platforms::darwinn::driver::Bitfield<14, 2> float_32_adder_; + platforms::darwinn::driver::Bitfield<16, 2> input_bus_transfer_; + platforms::darwinn::driver::Bitfield<18, 2> wide_data_transfer_; + } reg_; +}; + +// Implements field level access for *SampleInterval CSR. +class SampleInterval { + public: + SampleInterval() : SampleInterval(/*value=*/0ULL) {} + SampleInterval(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_value(uint64 value) { reg_.value_ = value; } + uint64 value() const { return reg_.value_(); } + + void set_enable(uint64 value) { reg_.enable_ = value; } + uint64 enable() const { return reg_.enable_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 7> value_; + platforms::darwinn::driver::Bitfield<7, 1> enable_; + } reg_; +}; + +// Implements field level access for tdpSampleInterval CSR. 
+class tdpSampleInterval { + public: + tdpSampleInterval() : tdpSampleInterval(/*value=*/0ULL) {} + tdpSampleInterval(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_value(uint64 value) { reg_.value_ = value; } + uint64 value() const { return reg_.value_(); } + + void set_enable(uint64 value) { reg_.enable_ = value; } + uint64 enable() const { return reg_.enable_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 14> value_; + platforms::darwinn::driver::Bitfield<14, 1> enable_; + } reg_; +}; + +// Implements field level access for *RunningSumInterval CSR. +class RunningSumInterval { + public: + RunningSumInterval() : RunningSumInterval(/*value=*/0ULL) {} + RunningSumInterval(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_value(uint64 value) { reg_.value_ = value; } + uint64 value() const { return reg_.value_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 2> value_; + } reg_; +}; + +// Implements field level access for didtThreshold CSR. +class didtThreshold { + public: + didtThreshold() : didtThreshold(/*value=*/0ULL) {} + didtThreshold(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_value(uint64 value) { reg_.value_ = value; } + uint64 value() const { return reg_.value_(); } + + void set_enable(uint64 value) { reg_.enable_ = value; } + uint64 enable() const { return reg_.enable_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 11> value_; + platforms::darwinn::driver::Bitfield<11, 1> enable_; + } reg_; +}; + +// Implements field level access for *ActionTable CSR. 
+class ActionTable { + public: + ActionTable() : ActionTable(/*value=*/0ULL) {} + ActionTable(uint64 value) { reg_.raw_ = value; } + + void set_raw(uint64 value) { reg_.raw_ = value; } + uint64 raw() const { return reg_.raw_; } + + void set_action0(uint64 value) { reg_.action0_ = value; } + uint64 action0() const { return reg_.action0_(); } + + void set_action1(uint64 value) { reg_.action1_ = value; } + uint64 action1() const { return reg_.action1_(); } + + void set_action2(uint64 value) { reg_.action2_ = value; } + uint64 action2() const { return reg_.action2_(); } + + void set_action3(uint64 value) { reg_.action3_ = value; } + uint64 action3() const { return reg_.action3_(); } + + void set_enable(uint64 value) { reg_.enable_ = value; } + uint64 enable() const { return reg_.enable_(); } + + private: + union { + uint64 raw_; + platforms::darwinn::driver::Bitfield<0, 3> action0_; + platforms::darwinn::driver::Bitfield<3, 3> action1_; + platforms::darwinn::driver::Bitfield<6, 3> action2_; + platforms::darwinn::driver::Bitfield<9, 3> action3_; + platforms::darwinn::driver::Bitfield<12, 1> enable_; + } reg_; +}; + +}  // namespace registers +}  // namespace config +}  // namespace driver +}  // namespace darwinn +}  // namespace platforms + +#endif  // DARWINN_DRIVER_CONFIG_POWER_THROTTLE_CSR_HELPER_H_ diff --git a/driver/config/queue_csr_offsets.h b/driver/config/queue_csr_offsets.h new file mode 100644 index 0000000..860a773 --- /dev/null +++ b/driver/config/queue_csr_offsets.h @@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_QUEUE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_QUEUE_CSR_OFFSETS_H_ + +#include <stddef.h> + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds the various CSR offsets for programming queue behaviors. +// Members are intentionally named to match the GCSR register names. +struct QueueCsrOffsets { + uint64 queue_control; + uint64 queue_status; + uint64 queue_descriptor_size; + uint64 queue_base; + uint64 queue_status_block_base; + uint64 queue_size; + uint64 queue_tail; + uint64 queue_fetched_head; + uint64 queue_completed_head; + uint64 queue_int_control; + uint64 queue_int_status; + uint64 queue_minimum_size; + uint64 queue_maximum_size; + uint64 queue_int_vector; +}; + +}  // namespace config +}  // namespace driver +}  // namespace darwinn +}  // namespace platforms + +#endif  // DARWINN_DRIVER_CONFIG_QUEUE_CSR_OFFSETS_H_ diff --git a/driver/config/register_constants.h b/driver/config/register_constants.h new file mode 100644 index 0000000..87f10da --- /dev/null +++ b/driver/config/register_constants.h @@ -0,0 +1,23 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_REGISTER_CONSTANTS_H_ +#define DARWINN_DRIVER_CONFIG_REGISTER_CONSTANTS_H_ + +#include "port/integral_types.h" + +// Offset used when a register does not exist for a project. +#define kCsrRegisterSpaceInvalidOffset static_cast<uint64>(-1) + +#endif  // DARWINN_DRIVER_CONFIG_REGISTER_CONSTANTS_H_ diff --git a/driver/config/register_file_csr_offsets.h b/driver/config/register_file_csr_offsets.h new file mode 100644 index 0000000..ed51c7b --- /dev/null +++ b/driver/config/register_file_csr_offsets.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_REGISTER_FILE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_REGISTER_FILE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds CSR offsets for accessing register file. Members are +// intentionally named to match the GCSR register names. +struct RegisterFileCsrOffsets { + // Points to first register in the register file. Subsequent registers can be + // accessed with increasing offsets if they exist. + uint64 RegisterFile; +}; + +}  // namespace config +}  // namespace driver +}  // namespace darwinn +}  // namespace platforms + +#endif  // DARWINN_DRIVER_CONFIG_REGISTER_FILE_CSR_OFFSETS_H_ diff --git a/driver/config/scalar_core_csr_offsets.h b/driver/config/scalar_core_csr_offsets.h new file mode 100644 index 0000000..bf872d0 --- /dev/null +++ b/driver/config/scalar_core_csr_offsets.h @@ -0,0 +1,127 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_SCALAR_CORE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_SCALAR_CORE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for scalar core. +// Members are intentionally named to match the GCSR register names. +struct ScalarCoreCsrOffsets { + // RunControls. 
+ // Legacy + uint64 scalarCoreRunControl; + uint64 executeControl; + uint64 avDataPopRunControl; + uint64 parameterPopRunControl; + + // Context 0 + uint64 scalarDatapath_0RunControl; + uint64 executeControl_0; + uint64 avDataPop_0RunControl; + uint64 parameterPop_0RunControl; + // Context 1 + uint64 scalarDatapath_1RunControl; + uint64 executeControl_1; + uint64 avDataPop_1RunControl; + uint64 parameterPop_1RunControl; + // Context 2 + uint64 scalarDatapath_2RunControl; + uint64 executeControl_2; + uint64 avDataPop_2RunControl; + uint64 parameterPop_2RunControl; + // Context 3 + uint64 scalarDatapath_3RunControl; + uint64 executeControl_3; + uint64 avDataPop_3RunControl; + uint64 parameterPop_3RunControl; + + // Legacy + uint64 infeedRunControl; + uint64 outfeedRunControl; + uint64 infeed1RunControl; + uint64 outfeed1RunControl; + + // Context Switching + uint64 contextControl; + uint64 contextStatus; + + // Context 0 + uint64 infeed_0_0RunControl; + uint64 outfeed_0_0RunControl; + uint64 infeed_0_1RunControl; + uint64 outfeed_0_1RunControl; + // Context 1 + uint64 infeed_1_0RunControl; + uint64 outfeed_1_0RunControl; + uint64 infeed_1_1RunControl; + uint64 outfeed_1_1RunControl; + // Context 2 + uint64 infeed_2_0RunControl; + uint64 outfeed_2_0RunControl; + uint64 infeed_2_1RunControl; + uint64 outfeed_2_1RunControl; + // Context 3 + uint64 infeed_3_0RunControl; + uint64 outfeed_3_0RunControl; + uint64 infeed_3_1RunControl; + uint64 outfeed_3_1RunControl; + + // Power related. + uint64 TilePowerInterval; + uint64 peakPowerSampleInterval; + uint64 tdpPowerSampleInterval; + uint64 didtPowerSampleInterval; + uint64 peakSampleAccumulator; + uint64 tdpSampleAccumulator; + uint64 didtSampleAccumulator; + uint64 peakThreshold0; + uint64 peakThreshold1; + uint64 peakThreshold2; + uint64 peakThreshold3; + uint64 tdpThreshold0; + uint64 tdpThreshold1; + uint64 tdpThreshold2; + uint64 tdpThreshold3; + uint64 didtThreshold0; + uint64 peakActionTable; + uint64 tdpActionTable; + uint64 didtActionTable; + uint64 peakRunningSum; + uint64 peakRunningSumInterval; + uint64 tdpRunningSum; + uint64 tdpRunningSumInterval; + uint64 didtRunningSum; + uint64 didtRunningSumInterval; + uint64 didtDifference; + uint64 packageTdpAction; + uint64 ThrottleStallCounter; + + // Scalar core cycle count. This could be used to synchronize timestamp + // between host and the TPU + uint64 cycleCount; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_SCALAR_CORE_CSR_OFFSETS_H_ diff --git a/driver/config/scu_csr_offsets.h b/driver/config/scu_csr_offsets.h new file mode 100644 index 0000000..b80d66e --- /dev/null +++ b/driver/config/scu_csr_offsets.h @@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
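// --- Illustrative usage sketch (editor's example, hedged). The power-throttle
// helper classes defined earlier in power_throttle_csr_helper.h (EnergyTable,
// SampleInterval, didtThreshold, ActionTable) wrap raw 64-bit CSR values, while
// ScalarCoreCsrOffsets above records where those CSRs live. One plausible way
// to combine them is sketched here; ProgramPeakActions and csr_write are
// made-up names standing in for whatever register-access interface the driver
// actually uses, and the field values are arbitrary examples.
#include <cstdint>
#include <functional>

#include "driver/config/power_throttle_csr_helper.h"
#include "driver/config/scalar_core_csr_offsets.h"

namespace {

using platforms::darwinn::driver::config::ScalarCoreCsrOffsets;
using platforms::darwinn::driver::config::registers::ActionTable;

// Builds a peak-power action table value field by field and hands it, together
// with the recorded offset, to an assumed register-write callback.
void ProgramPeakActions(
    const ScalarCoreCsrOffsets& offsets,
    const std::function<void(uint64_t offset, uint64_t value)>& csr_write) {
  ActionTable actions;     // Default constructor starts from a raw value of 0.
  actions.set_action0(1);  // Action taken when threshold 0 trips (3-bit field).
  actions.set_action1(2);
  actions.set_action2(3);
  actions.set_action3(4);
  actions.set_enable(1);   // Enable bit of the table.
  csr_write(offsets.peakActionTable, actions.raw());
}

}  // namespace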
+ +#ifndef DARWINN_DRIVER_CONFIG_SCU_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_SCU_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for programming the SCU in Beagle +// (the block containing state machines to control boot and power sequences) +// Members are intentionally named to match the GCSR register names. +struct ScuCsrOffsets { + // The SCU control registers have generic names but each contain many small + // fields which are reflected in the spec (should use csr_helper to access) + uint64 scu_ctrl_0; + uint64 scu_ctrl_1; + uint64 scu_ctrl_2; + uint64 scu_ctrl_3; + uint64 scu_ctrl_4; + uint64 scu_ctrl_5; + uint64 scu_ctr_6; + uint64 scu_ctr_7; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_SCU_CSR_OFFSETS_H_ diff --git a/driver/config/sync_flag_csr_offsets.h b/driver/config/sync_flag_csr_offsets.h new file mode 100644 index 0000000..dc9fc0a --- /dev/null +++ b/driver/config/sync_flag_csr_offsets.h @@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_SYNC_FLAG_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_SYNC_FLAG_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for accessing sync flags for scalar +// core and tiles. +struct SyncFlagCsrOffsets { + // Intentionally named to match the GCSR register names. + uint64 SyncCounter; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_SYNC_FLAG_CSR_OFFSETS_H_ diff --git a/driver/config/tile_config_csr_offsets.h b/driver/config/tile_config_csr_offsets.h new file mode 100644 index 0000000..8a09f24 --- /dev/null +++ b/driver/config/tile_config_csr_offsets.h @@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_TILE_CONFIG_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_TILE_CONFIG_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for configuring indirect tile accesses. 
+// Members are intentionally named to match the GCSR register names. +struct TileConfigCsrOffsets { + // Used only by driver. + uint64 tileconfig0; + + // Used by debugger, and other purposes. + uint64 tileconfig1; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_TILE_CONFIG_CSR_OFFSETS_H_ diff --git a/driver/config/tile_csr_offsets.h b/driver/config/tile_csr_offsets.h new file mode 100644 index 0000000..d234ede --- /dev/null +++ b/driver/config/tile_csr_offsets.h @@ -0,0 +1,98 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_TILE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_TILE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for tiles. Members are intentionally +// named to match the GCSR register names. +struct TileCsrOffsets { + // RunControls to change run state. + uint64 opRunControl; + uint64 narrowToNarrowRunControl; + uint64 narrowToWideRunControl; + uint64 wideToNarrowRunControl; + // When we enable the wider thread issue feature, we get multiple + // of these run controls per pipeline for the opcontrol, narrow to wide + // and wide to narrow. We're using 8 of these as a maximum issue width + // at this point. The driver will only use the registers that are valid + // for any given configuration. + // TODO + uint64 opRunControl_0; + uint64 narrowToWideRunControl_0; + uint64 wideToNarrowRunControl_0; + uint64 opRunControl_1; + uint64 narrowToWideRunControl_1; + uint64 wideToNarrowRunControl_1; + uint64 opRunControl_2; + uint64 narrowToWideRunControl_2; + uint64 wideToNarrowRunControl_2; + uint64 opRunControl_3; + uint64 narrowToWideRunControl_3; + uint64 wideToNarrowRunControl_3; + uint64 opRunControl_4; + uint64 narrowToWideRunControl_4; + uint64 wideToNarrowRunControl_4; + uint64 opRunControl_5; + uint64 narrowToWideRunControl_5; + uint64 wideToNarrowRunControl_5; + uint64 opRunControl_6; + uint64 narrowToWideRunControl_6; + uint64 wideToNarrowRunControl_6; + uint64 opRunControl_7; + uint64 narrowToWideRunControl_7; + uint64 wideToNarrowRunControl_7; + uint64 ringBusConsumer0RunControl; + uint64 ringBusConsumer1RunControl; + uint64 ringBusProducerRunControl; + uint64 meshBus0RunControl; + uint64 meshBus1RunControl; + uint64 meshBus2RunControl; + uint64 meshBus3RunControl; + + // Deep sleep register to control power state. + uint64 deepSleep; + + // Narrow memory retention and isolation. + uint64 narrowMemoryIsolation; + uint64 narrowMemoryRetention; + + // Power related. + uint64 EnergyTable; + uint64 didtSampleInterval; + uint64 didtRunningSumInterval; + uint64 opAccumulateRegister; + uint64 didtRunningSumRegister; + uint64 didtThreshold0; + + // Narrow memory base and bound of virtual contexts. 
+ uint64 narrowMemoryContext_0; + uint64 narrowMemoryContext_1; + uint64 narrowMemoryContext_2; + uint64 narrowMemoryContext_3; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_TILE_CSR_OFFSETS_H_ diff --git a/driver/config/tile_thread_csr_offsets.h b/driver/config/tile_thread_csr_offsets.h new file mode 100644 index 0000000..da6a1ba --- /dev/null +++ b/driver/config/tile_thread_csr_offsets.h @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_TILE_THREAD_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_TILE_THREAD_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for tilethreads. +// Members are intentionally named to match the GCSR register names. +struct TileThreadCsrOffsets { + // RunControls to change run state. + uint64 opRunControl_0; + uint64 narrowToWideRunControl_0; + uint64 wideToNarrowRunControl_0; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_TILE_THREAD_CSR_OFFSETS_H_ diff --git a/driver/config/tile_thread_trace_csr_offsets.h b/driver/config/tile_thread_trace_csr_offsets.h new file mode 100644 index 0000000..646e77c --- /dev/null +++ b/driver/config/tile_thread_trace_csr_offsets.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_TILE_THREAD_TRACE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_TILE_THREAD_TRACE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for tilethreads. +// Members are intentionally named to match the GCSR register names. +struct TileThreadTraceCsrOffsets { + // RunControls to change run state. 
+ uint64 TimeStampUnit; + uint64 StallCauseSelect; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_TILE_THREAD_TRACE_CSR_OFFSETS_H_ diff --git a/driver/config/trace_csr_offsets.h b/driver/config/trace_csr_offsets.h new file mode 100644 index 0000000..79f1740 --- /dev/null +++ b/driver/config/trace_csr_offsets.h @@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_TRACE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_TRACE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for performance tracing. +// Members are intentionally named to match the GCSR register names. +struct TraceCsrOffsets { + uint64 OverwriteMode; + uint64 EnableTracing; + uint64 Trace; + uint64 TimeStampUnit; + uint64 StallCauseSelect; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_TRACE_CSR_OFFSETS_H_ diff --git a/driver/config/usb_csr_offsets.h b/driver/config/usb_csr_offsets.h new file mode 100644 index 0000000..9dbafd4 --- /dev/null +++ b/driver/config/usb_csr_offsets.h @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_USB_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_USB_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for USB HIB. +// Members are intentionally named to match the GCSR register names. +struct UsbCsrOffsets { + uint64 outfeed_chunk_length; + uint64 descr_ep; + uint64 ep_status_credit; + uint64 multi_bo_ep; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_USB_CSR_OFFSETS_H_ diff --git a/driver/config/wire_csr_offsets.h b/driver/config/wire_csr_offsets.h new file mode 100644 index 0000000..5cd5755 --- /dev/null +++ b/driver/config/wire_csr_offsets.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_CONFIG_WIRE_CSR_OFFSETS_H_ +#define DARWINN_DRIVER_CONFIG_WIRE_CSR_OFFSETS_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace config { + +// This struct holds various CSR offsets for handling wire interrupts. +// Members are intentionally named to match the GCSR register names. +struct WireCsrOffsets { + // Tells which interrupts should be serviced. + uint64 wire_int_pending_bit_array; + uint64 wire_int_mask_array; +}; + +} // namespace config +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_CONFIG_WIRE_CSR_OFFSETS_H_ diff --git a/driver/default_telemeter.h b/driver/default_telemeter.h new file mode 100644 index 0000000..5f6019c --- /dev/null +++ b/driver/default_telemeter.h @@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DEFAULT_TELEMETER_H_ +#define DARWINN_DRIVER_DEFAULT_TELEMETER_H_ + +#include "api/telemeter_interface.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// This is the default implementation of TelemeterInterface. By default, every +// operation is a NOP. +class DefaultTelemeter : public api::TelemeterInterface { + public: + void LogWatchdogTimeout( + const api::ExecutionContextInterface& context) override {} +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DEFAULT_TELEMETER_H_ diff --git a/driver/device_buffer.cc b/driver/device_buffer.cc new file mode 100644 index 0000000..d7c394d --- /dev/null +++ b/driver/device_buffer.cc @@ -0,0 +1,76 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/device_buffer.h" + +#include <stddef.h> + +#include "port/integral_types.h" +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +void DeviceBuffer::Clear() { + type_ = Type::kInvalid; + size_bytes_ = 0; + device_address_ = 0; +} + +DeviceBuffer::DeviceBuffer(uint64 device_address, size_t size_bytes) + : type_(Type::kDefault), + size_bytes_(size_bytes), + device_address_(device_address) {} + +bool DeviceBuffer::operator==(const DeviceBuffer& rhs) const { + return type_ == rhs.type_ && size_bytes_ == rhs.size_bytes_ && + device_address_ == rhs.device_address_; +} + +bool DeviceBuffer::operator!=(const DeviceBuffer& rhs) const { + return !(*this == rhs); +} + +DeviceBuffer::DeviceBuffer(DeviceBuffer&& other) + : type_(other.type_), + size_bytes_(other.size_bytes_), + device_address_(other.device_address_) { + other.Clear(); +} + +DeviceBuffer& DeviceBuffer::operator=(DeviceBuffer&& other) { + if (this != &other) { + type_ = other.type_; + size_bytes_ = other.size_bytes_; + device_address_ = other.device_address_; + + other.Clear(); + } + return *this; +} + +DeviceBuffer DeviceBuffer::Slice(uint64 byte_offset, size_t size_bytes, + bool allow_overflow) const { + if (!allow_overflow) { + CHECK_LE(byte_offset + size_bytes, size_bytes_) + << "Overflowed underlying DeviceBuffer"; + } + const uint64 new_device_address = device_address_ + byte_offset; + return DeviceBuffer(new_device_address, size_bytes); +} + +}  // namespace driver +}  // namespace darwinn +}  // namespace platforms diff --git a/driver/device_buffer.h b/driver/device_buffer.h new file mode 100644 index 0000000..1b5d7d4 --- /dev/null +++ b/driver/device_buffer.h @@ -0,0 +1,99 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DEVICE_BUFFER_H_ +#define DARWINN_DRIVER_DEVICE_BUFFER_H_ + +#include <stddef.h> +#include <string> +#include <unordered_map> +#include <vector> + +#include "port/integral_types.h" +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Abstracts a device addressable buffer. Movable and copyable. +class DeviceBuffer { + public: + // Convenience structure for keeping track of named array of DeviceBuffers. + using NamedMap = std::unordered_map<std::string, std::vector<DeviceBuffer>>; + + // Default constructor. Defaults to an invalid non-existent buffer. + DeviceBuffer() = default; + + // Constructors for a device accessible buffer. + DeviceBuffer(uint64 device_address, size_t size_bytes); + + // This type is copyable, with default implementations. + DeviceBuffer(const DeviceBuffer&) = default; + DeviceBuffer& operator=(const DeviceBuffer&) = default; + + // This type is movable. + DeviceBuffer(DeviceBuffer&& other); + DeviceBuffer& operator=(DeviceBuffer&& other); + + // Destructors. + ~DeviceBuffer() = default; + + // Size of this buffer in bytes. + size_t size_bytes() const { return size_bytes_; } + + // Returns true if buffer is valid. + bool IsValid() const { return type_ != Type::kInvalid; } + + // Returns the device address. 
+ uint64 device_address() const { return device_address_; } + + // Equality operators. + bool operator==(const DeviceBuffer& rhs) const; + bool operator!=(const DeviceBuffer& rhs) const; + + // Returns a DeviceBuffer that starts from the "byte_offset" and consumes + // "size_bytes". Internally fails if created DeviceBuffer accesses outside of + // current DeviceBuffer and "allow_overflow" is false (the default). + DeviceBuffer Slice(uint64 byte_offset, size_t size_bytes, + bool allow_overflow = false) const; + + private: + // Type for the buffer. + enum class Type { + // Invalid. + kInvalid = 0, + + // Default device buffer (only one type for now.) + kDefault = 1, + }; + + // Clears all variables. + void Clear(); + + // Type for the buffer. + Type type_{Type::kInvalid}; + + // Size of the buffer. + size_t size_bytes_{0}; + + // Points to device addressable buffer. Valid when type is kDeviceBuffer. + uint64 device_address_{0}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DEVICE_BUFFER_H_ diff --git a/driver/device_buffer_mapper.cc b/driver/device_buffer_mapper.cc new file mode 100644 index 0000000..51ac8f5 --- /dev/null +++ b/driver/device_buffer_mapper.cc @@ -0,0 +1,249 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
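As an illustration of how DeviceBufferMapper::MapMultiple (below, in device_buffer_mapper.cc) coalesces host buffers that share or abut pages into single mappings, the following self-contained sketch runs the same tagged start/end sweep on bare page addresses. It is illustrative only: the function name MergePageRanges is invented, and the real code operates on Buffer objects with kHostPageSize/GetPageAddress rather than raw integers.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Merges page-aligned [start, end) ranges that touch or overlap, using the
// same low-bit tagging trick as DeviceBufferMapper::MapMultiple below.
std::vector<std::pair<uint64_t, uint64_t>> MergePageRanges(
    const std::vector<std::pair<uint64_t, uint64_t>>& ranges) {
  constexpr uint64_t kEndBit = 1;  // Pages are aligned, so the low bit is free.
  std::vector<uint64_t> points;
  points.reserve(ranges.size() * 2);
  for (const auto& range : ranges) {
    points.push_back(range.first);             // Start point: low bit clear.
    points.push_back(range.second | kEndBit);  // End point: low bit set.
  }
  // Equal addresses sort start-before-end, so touching ranges get merged.
  std::sort(points.begin(), points.end());

  std::vector<std::pair<uint64_t, uint64_t>> merged;
  uint64_t open_start = 0;
  int open = 0;  // Running count of starts minus ends seen so far.
  for (uint64_t point : points) {
    if (point & kEndBit) {
      if (--open == 0) merged.push_back({open_start, point & ~kEndBit});
    } else {
      if (open++ == 0) open_start = point;
    }
  }
  // E.g. {0x1000,0x3000} and {0x3000,0x5000} come back as one {0x1000,0x5000}.
  return merged;
}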
+ +#include "driver/device_buffer_mapper.h" + +#include +#include +#include +#include + +#include "api/buffer.h" +#include "driver/hardware_structures.h" +#include "driver/memory/address_utilities.h" +#include "driver/memory/dma_direction.h" +#include "port/cleanup.h" +#include "port/logging.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +DeviceBufferMapper::DeviceBufferMapper(AddressSpace* address_space) + : address_space_(address_space) { + CHECK(address_space != nullptr); +} + +util::Status DeviceBufferMapper::UnmapAll() { + TRACE_SCOPE("DeviceBufferMapper::UnmapAll"); + + RETURN_IF_ERROR(UnmapMultiple(instruction_mappings_)); + RETURN_IF_ERROR(Unmap(std::move(scratch_))); + RETURN_IF_ERROR(UnmapMultiple(input_mappings_)); + RETURN_IF_ERROR(UnmapMultiple(output_mappings_)); + + inputs_.clear(); + input_mappings_.clear(); + outputs_.clear(); + output_mappings_.clear(); + instructions_.clear(); + instruction_mappings_.clear(); + return util::Status(); // OK +} + +util::Status DeviceBufferMapper::MapInputs(const Buffer::NamedMap& buffers) { + TRACE_SCOPE("DeviceBufferMapper::MapInputs"); + return MapMultiple(buffers, DmaDirection::kToDevice, inputs_, + input_mappings_); +} + +util::Status DeviceBufferMapper::MapOutputs(const Buffer::NamedMap& buffers) { + TRACE_SCOPE("DeviceBufferMapper::MapOutputs"); + return MapMultiple(buffers, DmaDirection::kFromDevice, outputs_, + output_mappings_); +} + +util::Status DeviceBufferMapper::MapScratch(const Buffer& buffer) { + TRACE_SCOPE("DeviceBufferMapper::MapScratch"); + DCHECK(!scratch_.IsValid()); + ASSIGN_OR_RETURN(scratch_, Map(buffer, DmaDirection::kBidirectional)); + + VLOG(3) << StringPrintf( + "Mapped scratch : %s -> 0x%016llx, %zu bytes.", buffer.ToString().c_str(), + static_cast( // NOLINT(runtime/int) + scratch_.device_address()), + scratch_.size_bytes()); + + return util::Status(); // OK +} + +util::Status DeviceBufferMapper::MapInstructions( + const std::vector& buffers) { + TRACE_SCOPE("DeviceBufferMapper::MapInstructions"); + if (!instruction_mappings_.empty()) { + return util::InvalidArgumentError("Instructions are already mapped."); + } + + static const std::string kInstructions = "instructions"; + + // For convenience, place the instructions in a NamedMap just like inputs or + // outputs. + Buffer::NamedMap map; + map[kInstructions] = buffers; + + DeviceBuffer::NamedMap device_map; + const util::Status ret = MapMultiple(map, DmaDirection::kToDevice, device_map, + instruction_mappings_); + instructions_ = std::move(device_map[kInstructions]); + return ret; +} + +util::StatusOr DeviceBufferMapper::Map(const Buffer& buffer, + DmaDirection direction) { + TRACE_SCOPE("DeviceBufferMapper::Map"); + if (buffer.IsValid()) { + return address_space_->MapMemory(buffer, direction, MappingTypeHint::kAny); + } + return DeviceBuffer(); // Invalid buffer. 
+}
+
+util::Status DeviceBufferMapper::Unmap(DeviceBuffer buffer) {
+  TRACE_SCOPE("DeviceBufferMapper::Unmap");
+  if (buffer.IsValid()) {
+    return address_space_->UnmapMemory(std::move(buffer));
+  }
+  return util::Status();  // OK
+}
+
+util::Status DeviceBufferMapper::MapMultiple(
+    const Buffer::NamedMap& buffers, DmaDirection direction,
+    /*out*/ DeviceBuffer::NamedMap& user_buffers,
+    /*out*/ std::vector<DeviceBuffer>& mapped_buffers) {
+  if (!user_buffers.empty() || !mapped_buffers.empty()) {
+    return util::InvalidArgumentError("Device buffer is already mapped.");
+  }
+
+  auto cleaner = MakeCleanup(
+      [this, &mapped_buffers] { CHECK_OK(UnmapMultiple(mapped_buffers)); });
+
+  // Separate the buffers into ptr- and non-ptr types.
+  std::vector<Buffer> ptr_buffers;
+  for (const auto& name_and_buffer : buffers) {
+    for (const auto& buffer : name_and_buffer.second) {
+      if (buffer.IsPtrType()) {
+        ptr_buffers.push_back(buffer);
+      }
+    }
+  }
+
+  // Coalesce adjacent buffers. Since the underlying implementation can only
+  // map whole pages, any buffers on the same page or adjacent pages can be
+  // merged into a single underlying Map call. The basic algorithm is as
+  // follows:
+  //
+  // 1. Create a vector containing all start and end points, keeping a tag
+  //    on each element indicating whether it was a start or end.
+  // 2. Sort the vector, and if a start and end point have the same address,
+  //    the start point should be first in sorted order.
+  // 3. Iterate over the vector. Keep a running count of #start-#end points
+  //    seen. Whenever this counter hits zero, that's the end of a merged
+  //    interval.
+  //
+  // Because all the addresses are page-aligned, we can use the low bit to
+  // distinguish between the start and end points.
+
+  constexpr uint64 kEndOfMappingBit = 1;
+
+  std::vector<uint64> addresses;
+  addresses.reserve(ptr_buffers.size() * 2);
+
+  // merged_intervals contains the start address of each merged interval.
+  // Pre-allocate space assuming that no merging will happen.
+  std::vector<uint8*> merged_intervals;
+  merged_intervals.reserve(ptr_buffers.size());
+
+  for (const auto& buffer : ptr_buffers) {
+    uint64 start = GetPageAddress(reinterpret_cast<uint64>(buffer.ptr()));
+    uint64 end =
+        start +
+        GetNumberPages(buffer.ptr(), buffer.size_bytes()) * kHostPageSize +
+        kEndOfMappingBit;
+    addresses.push_back(start);
+    addresses.push_back(end);
+  }
+
+  std::sort(addresses.begin(), addresses.end());
+
+  int count = 0;
+  for (uint64 address : addresses) {
+    if (address & kEndOfMappingBit) {
+      --count;
+      CHECK_GE(count, 0);
+      if (count == 0) {
+        uint8* start = merged_intervals.back();
+        uint8* end = reinterpret_cast<uint8*>(address - kEndOfMappingBit);
+        Buffer merged_buffer(start, end - start);
+        ASSIGN_OR_RETURN(auto device_buffer, Map(merged_buffer, direction));
+        mapped_buffers.push_back(device_buffer);
+      }
+    } else {
+      if (count == 0) {
+        merged_intervals.push_back(reinterpret_cast<uint8*>(address));
+      }
+      ++count;
+    }
+  }
+
+  // Figure out where the user's device buffers are within the merged buffers.
+  for (const auto& name_and_buffer : buffers) {
+    for (const auto& buffer : name_and_buffer.second) {
+      DeviceBuffer device_buffer;
+      if (buffer.IsPtrType()) {
+        // Find the index of the corresponding merged buffer. In C++, there is
+        // no way to directly binary search for an element that's less than a
+        // given value, so instead we look for the closest one that's strictly
+        // greater and subtract one from the index.
+ const auto next = std::upper_bound( + merged_intervals.begin(), merged_intervals.end(), buffer.ptr()); + int index = next - merged_intervals.begin() - 1; + const auto merged = reinterpret_cast(merged_intervals[index]); + const auto& mapped = mapped_buffers[index]; + device_buffer = + DeviceBuffer(mapped.device_address() + + static_cast(buffer.ptr() - merged), + buffer.size_bytes()); + } else { + ASSIGN_OR_RETURN(device_buffer, Map(buffer, direction)); + mapped_buffers.push_back(device_buffer); + } + + VLOG(3) << StringPrintf( + "Mapped \"%s\" : %s -> 0x%016llx, %zu bytes. Direction=%d", + name_and_buffer.first.c_str(), buffer.ToString().c_str(), + static_cast( // NOLINT(runtime/int) + device_buffer.device_address()), + device_buffer.size_bytes(), direction); + + user_buffers[name_and_buffer.first].push_back(std::move(device_buffer)); + } + } + + cleaner.release(); + return util::OkStatus(); +} + +util::Status DeviceBufferMapper::UnmapMultiple( + std::vector& device_buffers) { + util::Status status; + for (auto& device_buffer : device_buffers) { + status.Update(Unmap(std::move(device_buffer))); + } + return status; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/device_buffer_mapper.h b/driver/device_buffer_mapper.h new file mode 100644 index 0000000..9c5ff83 --- /dev/null +++ b/driver/device_buffer_mapper.h @@ -0,0 +1,175 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DEVICE_BUFFER_MAPPER_H_ +#define DARWINN_DRIVER_DEVICE_BUFFER_MAPPER_H_ + +#include +#include + +#include "api/buffer.h" +#include "driver/device_buffer.h" +#include "driver/memory/address_space.h" +#include "driver/memory/dma_direction.h" +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Thread-unsafe. +// Maps request-specific Buffers to DeviceBuffers, and keeps track of +// DeviceBuffers. These include: input, output, instruction and scratch. +// Note that parameters are mapped and owned by ExecutableReference. +class DeviceBufferMapper { + public: + explicit DeviceBufferMapper(AddressSpace* address_space); + ~DeviceBufferMapper() = default; + + // This class is neither copyable nor movable. + DeviceBufferMapper(const DeviceBufferMapper&) = delete; + DeviceBufferMapper& operator=(const DeviceBufferMapper&) = delete; + + // Unmaps all per-request buffers. It is safe to call this method for cleanup + // even if DeviceBuffers are partially mapped. + util::Status UnmapAll(); + + // Maps given buffers to DeviceBuffers. + util::Status MapInputs(const Buffer::NamedMap& buffers); + util::Status MapOutputs(const Buffer::NamedMap& buffers); + util::Status MapScratch(const Buffer& buffer); + util::Status MapInstructions(const std::vector& buffers); + + // Returns mapped DeviceBuffers. 
+ const DeviceBuffer::NamedMap& GetInputDeviceBuffers() const { + return inputs_; + } + const DeviceBuffer::NamedMap& GetOutputDeviceBuffers() const { + return outputs_; + } + const DeviceBuffer& GetScratchDeviceBuffer() const { return scratch_; } + const std::vector& GetInstructionDeviceBuffers() const { + return instructions_; + } + + // Returns mapped DeviceBuffer for given argument. + const DeviceBuffer& GetInputDeviceBuffer(const std::string& name, + int batch) const { + return inputs_.at(name)[batch]; + } + const DeviceBuffer& GetOutputDeviceBuffer(const std::string& name, + int batch) const { + return outputs_.at(name)[batch]; + } + const DeviceBuffer& GetInstructionDeviceBuffer(int chunk_id) const { + DCHECK_LT(chunk_id, instructions_.size()); + return instructions_[chunk_id]; + } + + private: + // Convenience function that wraps AddressSpace#Map() handling invalid + // buffers. + util::StatusOr Map(const Buffer& buffer, + DmaDirection direction); + + // Convenience function that wraps AddressSpace#UnmapMemory() handling invalid + // buffers. + util::Status Unmap(DeviceBuffer buffer); + + // Helper function to map multiple buffers, merging adjacent buffers. + // - Fills user_buffers with a map of device buffers that directly correspond + // to the passed in buffers. Data parallel elements are represented as + // separate entries, even if the memory is contiguous. These device buffers + // are suitable for use in the instruction linking process. + // - Fills mapped_buffers with the merged list of device buffers that actually + // got mapped. These are the device buffers that need to be unmapped later. + util::Status MapMultiple(const Buffer::NamedMap& buffers, + DmaDirection direction, + /*out*/ DeviceBuffer::NamedMap& user_buffers, + /*out*/ std::vector& mapped_buffers); + + // Helper function to unmap multiple buffers. All passed in buffers will be + // invalidated by this call. + util::Status UnmapMultiple(std::vector& device_buffers); + + // Address space used for mapping. + AddressSpace* const address_space_; + + // Scratch buffer. Could be invalid. + DeviceBuffer scratch_; + + // Input/output buffers. + // input/output[layer_name][batch_id] = DeviceBuffer + DeviceBuffer::NamedMap inputs_; + DeviceBuffer::NamedMap outputs_; + + // Actual mappings that were created, after coalescing adjacent buffers. These + // are the mappings that need to be unmapped at the end of the request. + std::vector input_mappings_; + std::vector output_mappings_; + + // Instruction buffers. + std::vector instructions_; + + // Actual mappings that were created for instructions, after coalescing + // adjacent buffers. + std::vector instruction_mappings_; +}; + +// Holds a mapped device buffer as well as a callback for unmapping. +class MappedDeviceBuffer { + public: + MappedDeviceBuffer() = default; + MappedDeviceBuffer( + const DeviceBuffer& device_buffer, + const std::function& unmapper) + : device_buffer_(device_buffer), + unmap_(std::bind(unmapper, device_buffer)) {} + + ~MappedDeviceBuffer() { + // We should have unmapped the buffer at this moment. + CHECK(!unmap_); + } + + // This type is not copyable; we can't have the same device buffer unmapped + // more than once. + MappedDeviceBuffer(const MappedDeviceBuffer&) = delete; + MappedDeviceBuffer& operator=(const MappedDeviceBuffer&) = delete; + + // This type is movable. 
+ MappedDeviceBuffer(MappedDeviceBuffer&& other) = default; + MappedDeviceBuffer& operator=(MappedDeviceBuffer&& other) = default; + + const DeviceBuffer& device_buffer() const { return device_buffer_; } + + // Unmaps the associated DeviceBuffer using the given unmapper. + util::Status Unmap() { + if (unmap_) RETURN_IF_ERROR(unmap_()); + unmap_ = nullptr; + return util::Status(); // OK. + } + + private: + DeviceBuffer device_buffer_; + std::function unmap_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DEVICE_BUFFER_MAPPER_H_ diff --git a/driver/dma_chunker.cc b/driver/dma_chunker.cc new file mode 100644 index 0000000..cbec6b6 --- /dev/null +++ b/driver/dma_chunker.cc @@ -0,0 +1,91 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/dma_chunker.h" + +#include +#include + +#include "port/logging.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +DeviceBuffer DmaChunker::GetNextChunk() { + const auto curr_offset = GetNextChunkOffset(); + const int remaining_bytes = buffer_.size_bytes() - curr_offset; + VLOG(10) << StringPrintf( + "Completed %zd bytes; Outstanding %zd bytes; Processing next %d bytes", + transferred_bytes_, active_bytes_, remaining_bytes); + + MarkActive(remaining_bytes); + return buffer_.Slice(curr_offset, remaining_bytes); +} + +DeviceBuffer DmaChunker::GetNextChunk(int num_bytes) { + const auto curr_offset = GetNextChunkOffset(); + const int remaining_bytes = buffer_.size_bytes() - curr_offset; + const int transfer_bytes = std::min(remaining_bytes, num_bytes); + VLOG(10) << StringPrintf( + "Completed %zd bytes; Outstanding %zd bytes; Processing next %d bytes", + transferred_bytes_, active_bytes_, transfer_bytes); + + MarkActive(transfer_bytes); + return buffer_.Slice(curr_offset, transfer_bytes); +} + +void DmaChunker::NotifyTransfer(int transferred_bytes) { + transferred_bytes_ += transferred_bytes; + CHECK_GE(active_bytes_, transferred_bytes); + switch (processing_) { + case HardwareProcessing::kCommitted: + active_bytes_ -= transferred_bytes; + break; + + case HardwareProcessing::kBestEffort: + // Active bytes may be partially dropped by HW. Re-chunk them. + active_bytes_ = 0; + break; + } + CHECK_LE(transferred_bytes_, buffer_.size_bytes()); +} + +int DmaChunker::GetNextChunkOffset() const { + switch (processing_) { + case HardwareProcessing::kCommitted: + return transferred_bytes_ + active_bytes_; + + case HardwareProcessing::kBestEffort: + return transferred_bytes_; + } +} + +void DmaChunker::MarkActive(int num_bytes) { + switch (processing_) { + case HardwareProcessing::kCommitted: + active_bytes_ += num_bytes; + return; + + case HardwareProcessing::kBestEffort: + // Previous active bytes are irrelavant as best-effort can drop them. 
+ active_bytes_ = num_bytes; + return; + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/dma_chunker.h b/driver/dma_chunker.h new file mode 100644 index 0000000..0b7bfdb --- /dev/null +++ b/driver/dma_chunker.h @@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DMA_CHUNKER_H_ +#define DARWINN_DRIVER_DMA_CHUNKER_H_ + +#include + +#include "driver/device_buffer.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// A class to chunk DMAs into smaller DMAs given hardware constraints. +// +// Hardware can be: +// (1) HardwareProcessing::kCommitted +// Chunk given out from this class will be always processed in full, so +// DmaChunker will give out next chunk from previously given out chunk. +// +// (2) HardwareProcessing::kBestEffort +// Chunk given out from this class will be processed best-effort, and may +// partially fulfilled. Until DmaChunker is notified a completion of transfer +// (could be partial number of bytes), DmaChunker will give out the same chunk. +class DmaChunker { + public: + // Indicates how DMA will be processed in HW. + enum class HardwareProcessing { + // Chunked DMA will be always processed in full by HW. + kCommitted, + + // Chunked DMA will be processed best-effort, and HW may partially perform + // DMA. + kBestEffort, + }; + + DmaChunker(HardwareProcessing processing, const DeviceBuffer& buffer) + : processing_(processing), buffer_(buffer) {} + + // Returns true if there is next DMA chunk. + bool HasNextChunk() const { + return GetNextChunkOffset() < buffer_.size_bytes(); + } + + // Returns next DMA chunk to perform in full. + DeviceBuffer GetNextChunk(); + + // Returns next DMA chunk to perform upto "num_bytes". + DeviceBuffer GetNextChunk(int num_bytes); + + // Notifies that "transferred_bytes" amount of data has been transferred. + void NotifyTransfer(int transferred_bytes); + + // Returns true if transfer is active/completed. + bool IsActive() const { return active_bytes_ > 0; } + bool IsCompleted() const { + return buffer_.size_bytes() == transferred_bytes_; + } + + // Returns total DMA buffer. + const DeviceBuffer& buffer() const { return buffer_; } + + // Returns how many active transfers are out, where each transfer is "bytes". + int GetActiveCounts(int bytes) const { + // Want to calculate CeilOfRatio(active_btyes_, bytes) + const int floor = active_bytes_ / bytes; + return floor + ((floor * bytes) < active_bytes_); + } + + private: + // Returns next chunk offset to transfer. + int GetNextChunkOffset() const; + + // Marks "num_bytes" as actively transferred. + void MarkActive(int num_bytes); + + // Hardware constraints. + const HardwareProcessing processing_; + + // DeviceBuffer underlying DMA. + const DeviceBuffer buffer_; + + // Number of actively transferring bytes. + size_t active_bytes_{0}; + + // Number of transferred bytes. 
+ size_t transferred_bytes_{0}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DMA_CHUNKER_H_ diff --git a/driver/dma_info.cc b/driver/dma_info.cc new file mode 100644 index 0000000..d39303f --- /dev/null +++ b/driver/dma_info.cc @@ -0,0 +1,82 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/dma_info.h" + +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// Returns a debugging-friendly string. +std::string ToString(DmaState state) { + switch (state) { + case DmaState::kPending: + return "pending"; + + case DmaState::kActive: + return "active"; + + case DmaState::kCompleted: + return "completed"; + + case DmaState::kError: + return "error"; + } +} + +// Returns a debugging-friendly string. +std::string ToString(const DeviceBuffer& buffer) { + return StringPrintf("device_address = 0x%llx, bytes = %zd", + static_cast(buffer.device_address()), + buffer.size_bytes()); +} + +} // namespace + +std::string DmaInfo::Dump() const { + std::string prefix = StringPrintf("DMA[%d]: ", id_); + switch (type_) { + case DmaDescriptorType::kInstruction: + return prefix + "Instruction: " + ToString(buffer_) + ", " + + ToString(state_); + case DmaDescriptorType::kInputActivation: + return prefix + "Input activation: " + ToString(buffer_) + ", " + + ToString(state_); + case DmaDescriptorType::kParameter: + return prefix + "Parameter: " + ToString(buffer_) + ", " + + ToString(state_); + case DmaDescriptorType::kOutputActivation: + return prefix + "Output activation: " + ToString(buffer_) + ", " + + ToString(state_); + case DmaDescriptorType::kScalarCoreInterrupt0: + return prefix + "SC interrupt 0"; + case DmaDescriptorType::kScalarCoreInterrupt1: + return prefix + "SC interrupt 1"; + case DmaDescriptorType::kScalarCoreInterrupt2: + return prefix + "SC interrupt 2"; + case DmaDescriptorType::kScalarCoreInterrupt3: + return prefix + "SC interrupt 3"; + case DmaDescriptorType::kLocalFence: + return prefix + "Local fence"; + case DmaDescriptorType::kGlobalFence: + return prefix + "Global fence"; + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/dma_info.h b/driver/dma_info.h new file mode 100644 index 0000000..b370a62 --- /dev/null +++ b/driver/dma_info.h @@ -0,0 +1,102 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
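A rough usage sketch for the DmaChunker defined above, in the committed-hardware case. It is illustrative only: SubmitDmaAndWait() and the 1 KiB chunk size are invented stand-ins; the point is that on kCommitted hardware every chunk is processed in full, so the caller acknowledges the whole chunk via NotifyTransfer().

#include "driver/device_buffer.h"
#include "driver/dma_chunker.h"

// Hypothetical helper that performs one DMA and blocks until the hardware has
// consumed it; stands in for whatever transport the driver actually uses.
void SubmitDmaAndWait(const platforms::darwinn::driver::DeviceBuffer& chunk);

void DrainCommitted(const platforms::darwinn::driver::DeviceBuffer& dma) {
  using platforms::darwinn::driver::DmaChunker;
  DmaChunker chunker(DmaChunker::HardwareProcessing::kCommitted, dma);
  constexpr int kChunkBytes = 1024;  // Assumed transfer granularity.
  while (chunker.HasNextChunk()) {
    auto chunk = chunker.GetNextChunk(kChunkBytes);
    SubmitDmaAndWait(chunk);
    // Committed hardware never drops bytes, so the full chunk retires here.
    chunker.NotifyTransfer(static_cast<int>(chunk.size_bytes()));
  }
  // chunker.IsCompleted() is now true.
}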
+ +#ifndef DARWINN_DRIVER_DMA_INFO_H_ +#define DARWINN_DRIVER_DMA_INFO_H_ + +#include + +#include "driver/device_buffer.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Possible DMA descriptor types. +enum class DmaDescriptorType { + kInstruction = 0, + kInputActivation = 1, + kParameter = 2, + kOutputActivation = 3, + kScalarCoreInterrupt0 = 4, + kScalarCoreInterrupt1 = 5, + kScalarCoreInterrupt2 = 6, + kScalarCoreInterrupt3 = 7, + + // Fence is not exposed to driver. + // Used to synchronize DMAs local to a Request. + kLocalFence = 8, + + // Used to synchronize DMAs across Requests. + kGlobalFence = 9, +}; + +// Tracks DMA status. +enum class DmaState { + // DMA has not started yet. + kPending, + + // DMA is on-the-fly. + kActive, + + // DMA has completed. + kCompleted, + + // DMA had an error. + kError, +}; + +// DMA information. +class DmaInfo { + public: + DmaInfo(int id, DmaDescriptorType type) : id_(id), type_(type) {} + DmaInfo(int id, DmaDescriptorType type, const DeviceBuffer& buffer) + : id_(id), type_(type), buffer_(buffer) {} + + // Accessors. + int id() const { return id_; } + DmaDescriptorType type() const { return type_; } + const DeviceBuffer& buffer() const { return buffer_; } + + // Returns true if DMA is in given state. + bool IsActive() const { return state_ == DmaState::kActive; } + bool IsCompleted() const { return state_ == DmaState::kCompleted; } + bool IsInError() const { return state_ == DmaState::kError; } + + // Sets to given state. + void MarkActive() { state_ = DmaState::kActive; } + void MarkCompleted() { state_ = DmaState::kCompleted; } + + // Returns debug-friendly information. + std::string Dump() const; + + private: + // ID. + int id_; + + // Type of DMA. + DmaDescriptorType type_; + + // DMA status. + DmaState state_{DmaState::kPending}; + + // Memory to DMA from the device point of view. + DeviceBuffer buffer_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DMA_INFO_H_ diff --git a/driver/dma_info_extractor.cc b/driver/dma_info_extractor.cc new file mode 100644 index 0000000..6d17727 --- /dev/null +++ b/driver/dma_info_extractor.cc @@ -0,0 +1,181 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/dma_info_extractor.h" + +#include "driver/memory/address_utilities.h" +#include "driver/package_registry.h" +#include "executable/executable_generated.h" +#include "port/logging.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +DmaInfoExtractor::DmaInfoExtractor(ExtractorType type, bool overlap_requests) + : type_(type), overlap_requests_(overlap_requests) {} + +std::list DmaInfoExtractor::ExtractDmaInfos( + const ExecutableReference& executable_reference, + const DeviceBufferMapper& buffers) const { + switch (type_) { + case ExtractorType::kInstructionDma: + return ExtractInstructionDmaInfos(buffers); + + case ExtractorType::kDmaHints: + return ExtractDmaHints(executable_reference, buffers); + + case ExtractorType::kFirstInstruction: + return ExtractFirstInstruction(buffers); + } +} + +std::list DmaInfoExtractor::ExtractInstructionDmaInfos( + const DeviceBufferMapper& buffers) const { + std::list dmas; + const auto& instructions = buffers.GetInstructionDeviceBuffers(); + + int id = 0; + for (const auto& buffer : instructions) { + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kInstruction, buffer)); + } + if (!overlap_requests_) { + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kGlobalFence)); + } + return dmas; +} + +std::list DmaInfoExtractor::ExtractDmaHints( + const ExecutableReference& executable_reference, + const DeviceBufferMapper& buffers) const { + CHECK(executable_reference.executable().dma_hints() != nullptr); + const DmaHints& dma_hints = *executable_reference.executable().dma_hints(); + std::list dmas; + int id = 0; + for (const auto& dma_hint : *dma_hints.hints()) { + switch (dma_hint->any_hint_type()) { + case AnyHint_DmaDescriptorHint: { + const auto& descriptor = dma_hint->any_hint_as_DmaDescriptorHint(); + const auto& meta = descriptor->meta(); + switch (meta->desc()) { + case Description_BASE_ADDRESS_INPUT_ACTIVATION: { + const auto& buffer = buffers.GetInputDeviceBuffer( + meta->name()->str(), meta->batch()); + // Input buffers may not be padded, so the DMA may request a small + // amount of data past the end of the input buffer. Double check + // that we don't cross a page boundary, but otherwise allow the + // DMA to read past the end of the buffer. 
+ uint64 last_page_of_buffer = GetPageAddress( + buffer.device_address() + buffer.size_bytes() - 1); + uint64 last_page_of_dma = GetPageAddress( + buffer.device_address() + descriptor->offset_in_bytes() + + descriptor->size_in_bytes() - 1); + CHECK_LE(last_page_of_dma, last_page_of_buffer); + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kInputActivation, + buffer.Slice(descriptor->offset_in_bytes(), + descriptor->size_in_bytes(), + /*allow_overflow=*/true))); + break; + } + + case Description_BASE_ADDRESS_OUTPUT_ACTIVATION: { + const auto& buffer = buffers.GetOutputDeviceBuffer( + meta->name()->str(), meta->batch()); + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kOutputActivation, + buffer.Slice(descriptor->offset_in_bytes(), + descriptor->size_in_bytes()))); + break; + } + + case Description_BASE_ADDRESS_PARAMETER: { + const auto& buffer = + executable_reference.GetParameterDeviceBuffer(); + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kParameter, + buffer.Slice(descriptor->offset_in_bytes(), + descriptor->size_in_bytes()))); + break; + } + + case Description_BASE_ADDRESS_SCRATCH: { + const auto& buffer = buffers.GetScratchDeviceBuffer(); + if (dma_hint->direction() == Direction_INFEED) { + dmas.push_back( + DmaInfo(id++, DmaDescriptorType::kInputActivation, + buffer.Slice(descriptor->offset_in_bytes(), + descriptor->size_in_bytes()))); + } else { + DCHECK_EQ(dma_hint->direction(), Direction_OUTFEED); + dmas.push_back( + DmaInfo(id++, DmaDescriptorType::kOutputActivation, + buffer.Slice(descriptor->offset_in_bytes(), + descriptor->size_in_bytes()))); + } + break; + } + } + break; + } + + case AnyHint_InstructionHint: { + const int chunk_id = + dma_hint->any_hint_as_InstructionHint()->instruction_chunk_index(); + const auto& buffer = buffers.GetInstructionDeviceBuffer(chunk_id); + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kInstruction, buffer)); + break; + } + + case AnyHint_InterruptHint: { + const auto& interrupt = dma_hint->any_hint_as_InterruptHint(); + const DmaDescriptorType type = static_cast( + static_cast(DmaDescriptorType::kScalarCoreInterrupt0) + + static_cast(interrupt->type())); + dmas.push_back(DmaInfo(id++, type)); + break; + } + + case AnyHint_FenceHint: { + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kLocalFence)); + break; + } + + case AnyHint_NONE: + LOG(FATAL) << StringPrintf("Unrecognized hint"); + break; + } + } + + // Add GlobalFence to enforce ordering when hints are not fully deterministic. + if (!dma_hints.fully_deterministic() || !overlap_requests_) { + dmas.push_back(DmaInfo(id++, DmaDescriptorType::kGlobalFence)); + } + + if (VLOG_IS_ON(10)) { + for (const auto& dma : dmas) { + VLOG(10) << dma.Dump(); + } + } + return dmas; +} + +std::list DmaInfoExtractor::ExtractFirstInstruction( + const DeviceBufferMapper& buffers) const { + const auto& instructions = buffers.GetInstructionDeviceBuffers(); + return {DmaInfo(/*id=*/0, DmaDescriptorType::kInstruction, instructions[0]), + DmaInfo(/*id=*/1, DmaDescriptorType::kGlobalFence)}; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/dma_info_extractor.h b/driver/dma_info_extractor.h new file mode 100644 index 0000000..3921f0d --- /dev/null +++ b/driver/dma_info_extractor.h @@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DMA_INFO_EXTRACTOR_H_ +#define DARWINN_DRIVER_DMA_INFO_EXTRACTOR_H_ + +#include + +#include "driver/device_buffer_mapper.h" +#include "driver/dma_info.h" +#include "driver/package_registry.h" +#include "executable/executable_generated.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Extracts DMAs to be performed by driver. +class DmaInfoExtractor { + public: + // Determines how to extract DMA infos for the executable. + enum class ExtractorType { + // Extracts only instruction DMAs for baseline PCIe usecase. + kInstructionDma = 0, + + // Extracts through DMA hints for USB usecase. + kDmaHints = 1, + + // Extracts only first instruction DMA for USB usecase. + kFirstInstruction = 2, + }; + + explicit DmaInfoExtractor(ExtractorType type) + : DmaInfoExtractor(type, true) {} + DmaInfoExtractor(ExtractorType type, bool overlap_requests); + virtual ~DmaInfoExtractor() = default; + + // Extracts a list of DMAs to be performed. + virtual std::list ExtractDmaInfos( + const ExecutableReference& executable_reference, + const DeviceBufferMapper& buffers) const; + + private: + // Extracts instruction DMAs. + std::list ExtractInstructionDmaInfos( + const DeviceBufferMapper& buffers) const; + + // Extracts DMA hints. + std::list ExtractDmaHints( + const ExecutableReference& executable_reference, + const DeviceBufferMapper& buffers) const; + + // Extracts first instruction DMA. + std::list ExtractFirstInstruction( + const DeviceBufferMapper& buffers) const; + + // Extractor type. + const ExtractorType type_; + + // True if requests can be overlapped. Should be set to false just for + // debugging. + const bool overlap_requests_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DMA_INFO_EXTRACTOR_H_ diff --git a/driver/dma_scheduler.h b/driver/dma_scheduler.h new file mode 100644 index 0000000..9ca94c9 --- /dev/null +++ b/driver/dma_scheduler.h @@ -0,0 +1,104 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DMA_SCHEDULER_H_ +#define DARWINN_DRIVER_DMA_SCHEDULER_H_ + +#include +#include + +#include "api/driver.h" +#include "driver/dma_info.h" +#include "driver/tpu_request.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Manages the processing order of DMAs from DarwiNN Request, and also keeps +// track of the requests. All implementation of DMA scheduler has to be +// thread-safe. 
+// +// Example usage: +// DmaScheduler scheduler; +// scheduler.Submit(request0); +// scheduler.Submit(request1); +// ... +// const auto* dma = scheduler.GetNextDma(); +// // Handle DMA. +// if DMA is completed: +// scheduler.NotifyDmaCompletion(dma); +// ... +// // when Request is complete +// scheduler.NotifyRequestCompletion(); +class DmaScheduler { + public: + DmaScheduler() = default; + + // This class is neither copyable nor movable. + DmaScheduler(const DmaScheduler&) = delete; + DmaScheduler& operator=(const DmaScheduler&) = delete; + + virtual ~DmaScheduler() = default; + + // Opens/closes DMA scheduler. + virtual util::Status Open() = 0; + virtual util::Status Close(api::Driver::ClosingMode mode) = 0; + + // Submits a request for execution on DarwiNN. + virtual util::Status Submit(std::shared_ptr request) = 0; + + // Returns next DMA type to be performed. Returns kLocalFence if there is no + // next DMA. + virtual util::StatusOr PeekNextDma() const = 0; + + // Returns DMA to perform. If there is no DMA to perform, returns nullptr. + // Target of pointers are internally maintained. + // DmaScheduler::NotifyDmaCompletion is a contract that given pointer is no + // longer used by external entity. + virtual util::StatusOr GetNextDma() = 0; + + // Notifies that DMA for given "dma_info" has completed. Returns an error if + // given "dma_info" cannot be completed. + virtual util::Status NotifyDmaCompletion(DmaInfo* dma_info) = 0; + + // Notifies when request has been completed, and performs any necessary + // cleanups. + virtual util::Status NotifyRequestCompletion() = 0; + + // Cancels all the pending requests that has not been submitted to DarwiNN + // device yet. + virtual util::Status CancelPendingRequests() = 0; + + // Waits until active requests are done. + virtual util::Status WaitActiveRequests() = 0; + + // Returns true if there is no DMAs to schedule. + virtual bool IsEmpty() const = 0; + + // Returns the upper bound on number of TPU cycles remaining to complete all + // scheduled tasks. + virtual int64 MaxRemainingCycles() const = 0; + + // Returns the oldest submitted request that's still active. + virtual util::StatusOr> GetOldestActiveRequest() + const = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DMA_SCHEDULER_H_ diff --git a/driver/driver.cc b/driver/driver.cc new file mode 100644 index 0000000..a3b377f --- /dev/null +++ b/driver/driver.cc @@ -0,0 +1,716 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/driver.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "api/execution_context_interface.h" +#include "api/package_reference.h" +#include "api/request.h" +#include "driver/package_registry.h" +#include "driver/request.h" +#include "driver/tpu_request.h" +#include "executable/executable_generated.h" +#include "port/blocking_counter.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/math_util.h" +#include "port/ptr_util.h" +#include "port/shared_mutex.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { + +using api::ExecutionContextInterface; + +} // namespace + +Driver::Driver(api::Chip chip, + std::unique_ptr executable_registry, + const api::DriverOptions& driver_options, + std::unique_ptr time_stamper) + : chip_(chip), + executable_registry_(std::move(executable_registry)), + time_stamper_(std::move(time_stamper)), + current_parameter_caching_token_(0), + debug_mode_(false), + max_scheduled_work_ns_(driver_options.max_scheduled_work_ns()) { + // Use the default_telemeter by default. + telemeter_interface_ = &default_telemeter_; + + operational_settings_.tpu_frequency_hz = driver_options.tpu_frequency_hz(); + operational_settings_.host_to_tpu_bps = driver_options.host_to_tpu_bps(); + + scheduler_thread_ = std::thread([this]() { SchedulerWorker(); }); +} + +Driver::~Driver() { + { + StdMutexLock scheduler_lock(&scheduler_mutex_); + destructing_ = true; + scheduler_wakeup_.notify_one(); + } + + if (scheduler_thread_.joinable()) { + scheduler_thread_.join(); + } +} + +std::string Driver::BadStateMessage(State expected_state) const { + return StringPrintf("Bad driver state. expected=%d, actual=%d.", + expected_state, state_); +} + +util::Status Driver::SetState(State next_state) { + switch (state_) { + case kOpen: + if (next_state == kClosing) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosing: + if (next_state == kClosed) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosed: + if (next_state == kOpen) { + state_ = next_state; + return util::Status(); // OK + } + break; + } + + // Illegal state transition. + return util::FailedPreconditionError(StringPrintf( + "Invalid state transition. current=%d, next=%d.", state_, next_state)); +} + +bool Driver::IsOpen() const { + ReaderMutexLock state_reader_lock(&state_mutex_); + return state_ == kOpen; +} + +bool Driver::IsError() const { return in_error_; } + +util::Status Driver::Open(bool debug_mode, bool context_lost) { + WriterMutexLock state_writer_lock(&state_mutex_); + if (num_clients_ > 0) { + if (context_lost) { + return util::InvalidArgumentError( + "context_lost was set at open() yet there were others holding the " + "driver open."); + } + + num_clients_++; + return util::Status(); // OK + } + + if (state_ != kClosed) { + return util::FailedPreconditionError(BadStateMessage(kClosed)); + } + + if (context_lost) { + executable_registry_->ResetParametersLoaded(); + } + + debug_mode_ = debug_mode; + RETURN_IF_ERROR(DoOpen(debug_mode)); + num_clients_++; + + // All good. Move state to open. + RETURN_IF_ERROR(SetState(kOpen)); + + return util::Status(); // OK. 
+} + +namespace { + +// Computes maximum execution time in ms by taking the ceiling of +// cycle / frequency in KHz. +int64 ComputeMETinMs(int64 cycles, int64 frequency) { + constexpr int64 kKilo = 1000LL; + if (cycles > 0 && frequency > 0) { + return 1 + (cycles - 1) / (frequency / kKilo); + } else { + return 0; + } +} + +} // namespace + +util::Status Driver::UpdateInitialTiming( + const api::PackageReference* api_package_reference) { + StdMutexLock lock(&submit_mutex_); + + const PackageReference* driver_package_reference = + static_cast(api_package_reference); + const auto* executable_reference = + driver_package_reference->MainExecutableReference(); + + // Don't bother calling the driver's SetExecutableTiming if it doesn't even + // support real-time mode, or if the driver's operating frequency is not set. + if (!HasImplementedRealtimeMode() || + operational_settings_.tpu_frequency_hz <= 0) { + return util::OkStatus(); + } + + // Producing initial guess for estimated execution time. + // More precise execution time can be obtained from real execution timing + // statistics. At this point, understating operating frequency is ok as we + // need a conservative estimation. + if (executable_reference->EstimatedCycles() > 0) { + api::Timing timing; + timing.max_execution_time_ms = static_cast( + ComputeMETinMs(executable_reference->EstimatedCycles(), + operational_settings_.tpu_frequency_hz)); + // This initial timing setting is set regardless whether the underlying + // driver supports real-time mode or not, thus is best effort. + return SetExecutableTiming(api_package_reference, timing); + } else { + // The executable doesn't carry estimated cycles information. + return util::OkStatus(); + } +} + +util::StatusOr Driver::RegisterExecutableFile( + const std::string& executable_filename) { + TRACE_SCOPE("Driver::RegisterExecutableFile"); + ASSIGN_OR_RETURN(auto* registered_package, + executable_registry_->RegisterFile(executable_filename)); + RETURN_IF_ERROR(UpdateInitialTiming(registered_package)); + return registered_package; +} + +util::StatusOr +Driver::RegisterExecutableSerialized(const std::string& executable_content) { + TRACE_SCOPE("Driver::RegisterExecutableSerialized"); + ASSIGN_OR_RETURN( + auto* registered_package, + executable_registry_->RegisterSerialized(executable_content)); + RETURN_IF_ERROR(UpdateInitialTiming(registered_package)); + return registered_package; +} + +util::StatusOr +Driver::RegisterExecutableSerialized(const char* executable_content, + size_t length) { + TRACE_SCOPE("Driver::RegisterExecutableSerialized"); + ASSIGN_OR_RETURN( + auto* registered_package, + executable_registry_->RegisterSerialized(executable_content, length)); + RETURN_IF_ERROR(UpdateInitialTiming(registered_package)); + return registered_package; +} + +// TODO Keeping parameters mapped for the entire time driver is +// open can lead to OOM even if we have enough memory for one request. +util::Status Driver::MapParameters(PackageReference& package_ref) { + TRACE_SCOPE("Driver::MapParameters"); + + // If this is the first time we are mapping parameters and the parameters are + // supposed to reside in the on-chip DRAM, we should transfer them first. + + for (auto* driver_executable_ref : package_ref.AllExecutableReferences()) { + RETURN_IF_ERROR(driver_executable_ref->PrepareParameters()); + const Buffer& buffer = driver_executable_ref->parameters(); + + // TODO Investigate if we need to optimize cache flushing here. 
+ ASSIGN_OR_RETURN(MappedDeviceBuffer mapped_device_buffer, + DoMapBuffer(buffer, DmaDirection::kToDevice)); + + const DeviceBuffer& device_buffer = mapped_device_buffer.device_buffer(); + VLOG(3) << absl::StrFormat("Mapped params : %s -> 0x%016llx, %zu bytes.", + buffer.ToString().c_str(), + device_buffer.device_address(), + device_buffer.size_bytes()); + RETURN_IF_ERROR(driver_executable_ref->SetMappedParameters( + std::move(mapped_device_buffer))); + } + + return util::OkStatus(); +} + +Buffer Driver::MakeBuffer(size_t size_bytes) const { + return DoMakeBuffer(size_bytes); +} + +util::Status Driver::UnregisterExecutable( + const api::PackageReference* executable_ref) { + ReaderMutexLock state_reader_lock(&state_mutex_); + + // Remove per-executable timing information from real-time scheduler. + if (HasImplementedRealtimeMode()) { + RETURN_IF_ERROR(RemoveExecutableTiming(executable_ref)); + } + + // TODO : should defer unregistering if there are pending + // requests. + return executable_registry_->Unregister(executable_ref); +} + +util::StatusOr> Driver::CreateRequest( + const api::PackageReference* api_package_ref) { + if (api_package_ref == nullptr) { + return util::InvalidArgumentError("Package reference is null."); + } + + const auto* package_ref = + static_cast(api_package_ref); + return {std::make_shared( + next_id_.fetch_add(1, std::memory_order_relaxed), *package_ref, + *time_stamper_)}; +} + +util::Status Driver::Submit(std::shared_ptr api_request, + api::Request::Done done_callback) { + TRACE_SCOPE("Driver::Submit"); + ReaderMutexLock state_reader_lock(&state_mutex_); + StdMutexLock submit_lock(&submit_mutex_); + + if (state_ != kOpen) { + return util::UnavailableError(BadStateMessage(kOpen)); + } + + auto request = std::static_pointer_cast(api_request); + RETURN_IF_ERROR(request->SetDone(std::move(done_callback))); + RETURN_IF_ERROR(request->Prepare()); + RETURN_IF_ERROR(CheckLatencyTolerance(request)); + + if (request->GetPriority() == 0) { + VLOG(4) << StringPrintf("Request [%d]: Submitting P0 request immediately.", + request->id()); + ASSIGN_OR_RETURN(auto remaining_tpu_requests, + request->RemainingTpuRequestCount()); + for (int i = 0; i < remaining_tpu_requests; ++i) { + RETURN_IF_ERROR(SubmitInferenceRequest(request)); + } + } else { + VLOG(4) << StringPrintf( + "Request [%d]: Pushing P%d request to its priority queue.", + request->id(), request->GetPriority()); + pending_requests_[request->GetPriority()].push(std::move(request)); + RETURN_IF_ERROR(TrySchedulePendingRequests()); + } + + return util::OkStatus(); +} + +util::Status Driver::CheckLatencyTolerance( + const std::shared_ptr& request) { + TRACE_SCOPE("Driver::CheckLatencyTolerance"); + const auto& package_ref = request->GetPackageReference(); + if (package_ref.LatencyToleranceMs() <= 0) { + // No latency requirement set. 
+ return util::OkStatus(); + } + + if (request->GetPriority() > 0) { + return util::InvalidArgumentError( + "Latency tolerance can only be set for P0 requests."); + } + + ASSIGN_OR_RETURN(auto tpu_request_count, request->RemainingTpuRequestCount()); + int64 estimated_cycles = + tpu_request_count * + package_ref.MainExecutableReference()->EstimatedCycles(); + + ASSIGN_OR_RETURN(bool needs_parameter_caching, + NeedsParameterCaching(request)); + if (needs_parameter_caching) { + estimated_cycles += + package_ref.ParameterCachingExecutableReference()->EstimatedCycles(); + } + + estimated_cycles += MaxRemainingCycles(); + + int64 estimated_time_ms = + ComputeMETinMs(estimated_cycles, operational_settings_.tpu_frequency_hz); + if (estimated_time_ms > package_ref.LatencyToleranceMs()) { + return util::DeadlineExceededError(absl::StrFormat( + "Estimated execution time (%lld ms) exceeds max tolerance (%lld ms).", + estimated_time_ms, package_ref.LatencyToleranceMs())); + } + + return util::OkStatus(); +} + +util::Status Driver::SubmitInferenceRequest(std::shared_ptr request) { + TRACE_SCOPE("Driver::SubmitInferenceRequest"); + const auto& package_ref = request->GetPackageReference(); + ASSIGN_OR_RETURN(auto parameters_mapped, package_ref.ParametersMapped()); + if (!parameters_mapped) { + // TODO Remove the const casts. + VLOG(5) << StringPrintf("Request [%d]: Need to map parameters.", + request->id()); + RETURN_IF_ERROR(MapParameters(const_cast(package_ref))); + } + + const auto& main_ref = request->MainExecutableReference(); + if (main_ref.ParameterCachingToken() == 0 || + main_ref.ParameterCachingToken() != current_parameter_caching_token_) { + ResetCachedParameters(); + } + + ASSIGN_OR_RETURN(bool needs_parameter_caching, + NeedsParameterCaching(request)); + if (needs_parameter_caching) { + VLOG(5) << StringPrintf("Request [%d]: Need to do parameter-caching.", + request->id()); + RETURN_IF_ERROR(SubmitParameterCachingRequest(request)); + } + + ASSIGN_OR_RETURN(auto tpu_request, + DoCreateRequest(request, &request->MainExecutableReference(), + TpuRequest::RequestType::INFERENCE)); + RETURN_IF_ERROR(request->PrepareTpuRequest(tpu_request)); + RETURN_IF_ERROR(DoSubmit(std::move(tpu_request))); + request->NotifySubmission(TpuRequest::RequestType::INFERENCE); + + return util::OkStatus(); +} + +util::StatusOr Driver::NeedsParameterCaching( + const std::shared_ptr& request) const { + const auto& package_ref = request->GetPackageReference(); + if (!package_ref.ParameterCachingEnabled()) { + return false; + } + + const auto& parameter_caching_ref = + package_ref.ParameterCachingExecutableReference(); + if (parameter_caching_ref->ParameterCachingToken() == 0) { + return util::InternalError("Parameter caching tag is not set."); + } + + return currently_cached_refs_.find(parameter_caching_ref) == + currently_cached_refs_.end(); +} + +util::Status Driver::SubmitParameterCachingRequest( + const std::shared_ptr& request) { + TRACE_SCOPE("Driver::SubmitParameterCachingRequest"); + auto parameter_caching_ref = + request->GetPackageReference().ParameterCachingExecutableReference(); + + current_parameter_caching_token_ = + parameter_caching_ref->ParameterCachingToken(); + currently_cached_refs_.insert(parameter_caching_ref); + + ASSIGN_OR_RETURN(auto tpu_request, + DoCreateRequest(request, parameter_caching_ref, + TpuRequest::RequestType::PARAMETER_CACHING)); + RETURN_IF_ERROR(tpu_request->SetDone([](int, const util::Status&) {})); + RETURN_IF_ERROR(DoSubmit(std::move(tpu_request))); + 
request->NotifySubmission(TpuRequest::RequestType::PARAMETER_CACHING); + + return util::OkStatus(); +} + +void Driver::ResetCachedParameters() { + current_parameter_caching_token_ = 0; + currently_cached_refs_.clear(); +} + +void Driver::SchedulerWorker() { + while (true) { + { + StdCondMutexLock lock(&scheduler_mutex_); + while (!schedule_more_requests_ && !destructing_) { + scheduler_wakeup_.wait(lock); + } + + if (destructing_) { + return; + } + + schedule_more_requests_ = false; + } + + ReaderMutexLock state_reader_lock(&state_mutex_); + StdMutexLock submit_lock(&submit_mutex_); + // TODO Improve handling of this error. + CHECK_OK(TrySchedulePendingRequests()); + } +} + +void Driver::HandleTpuRequestCompletion() { + StdMutexLock lock(&scheduler_mutex_); + schedule_more_requests_ = true; + scheduler_wakeup_.notify_one(); +} + +util::Status Driver::TrySchedulePendingRequests() { + for (auto& priority_and_queue : pending_requests_) { + auto& request_queue = priority_and_queue.second; + + while (!request_queue.empty()) { + ASSIGN_OR_RETURN(bool can_schedule, + CanScheduleTpuRequest(request_queue.front())); + if (!can_schedule) { + VLOG(5) << absl::StrFormat( + "Already have %lld cycles in scheduler, no need to schedule more " + "work.", + MaxRemainingCycles()); + return util::OkStatus(); + } + + auto request = request_queue.front(); + VLOG(5) << absl::StrFormat( + "Request [%d]: Scheduling one more TPU request that takes %lld " + "cycles.", + request->id(), request->EstimatedCyclesPerInference()); + + RETURN_IF_ERROR(SubmitInferenceRequest(request)); + + ASSIGN_OR_RETURN(auto remaining_tpu_requests, + request->RemainingTpuRequestCount()); + if (remaining_tpu_requests == 0) { + VLOG(5) << StringPrintf( + "Request [%d]: All TPU requests are now submitted.", request->id()); + request_queue.pop(); + } + } + } + + return util::OkStatus(); +} + +util::StatusOr Driver::CanScheduleTpuRequest( + const std::shared_ptr& request) { + if (request->GetPriority() == 0) { + return util::InvalidArgumentError( + "P0 requests should be immediately scheduled."); + } + + if (max_scheduled_work_ns_ < 0) { + VLOG(7) << StringPrintf( + "max_scheduled_work_ns=%0.f, all requests are scheduled immediately.", + max_scheduled_work_ns_); + return true; + } + + int64 remaining_cycles = MaxRemainingCycles(); + if (remaining_cycles == 0) { + VLOG(7) << "Nothing is in the scheduler, submit one TPU request no matter " + "what."; + return true; + } + + int64 max_cycles_to_schedule = + static_cast( + (max_scheduled_work_ns_ * + static_cast(operational_settings_.tpu_frequency_hz)) / + 1e9) - + remaining_cycles; + + int64 total_cycles = request->EstimatedCyclesPerInference(); + ASSIGN_OR_RETURN(auto needs_parameter_caching, + NeedsParameterCaching(request)); + if (needs_parameter_caching) { + total_cycles += request->GetPackageReference() + .ParameterCachingExecutableReference() + ->EstimatedCycles(); + } + + VLOG(7) << absl::StrFormat( + "Request [%d]: Total cycles needed for scheduling a new inference: %lld, " + "%lld available.", + request->id(), total_cycles, max_cycles_to_schedule); + return (max_cycles_to_schedule >= total_cycles); +} + +util::Status Driver::CancelAllPendingRequests() { + StdMutexLock submit_lock(&submit_mutex_); + + for (auto& priority_and_queue : pending_requests_) { + auto& request_queue = priority_and_queue.second; + + while (!request_queue.empty()) { + auto request = request_queue.front(); + ASSIGN_OR_RETURN(auto remaining_tpu_requests, + request->RemainingTpuRequestCount()); + VLOG(4) << 
StringPrintf( + "Request [%d]: Cancelling %d remaining TPU requests.", request->id(), + remaining_tpu_requests); + + RETURN_IF_ERROR(request->HandleTpuRequestsDone( + util::CancelledError("Request cancelled."), remaining_tpu_requests)); + request_queue.pop(); + } + } + + return util::OkStatus(); +} + +util::Status Driver::Execute(std::shared_ptr request) { + BlockingCounter counter(1); + util::Status final_status; + + auto done_callback = [&counter, &final_status](int id, util::Status status) { + final_status = std::move(status); + counter.DecrementCount(); + }; + + // Submit asynchronously and wait. + RETURN_IF_ERROR(Submit(std::move(request), std::move(done_callback))); + + counter.Wait(); + + return final_status; +} + +util::Status Driver::Execute( + const std::vector>& requests) { + BlockingCounter counter(requests.size()); + std::mutex status_mutex; + util::Status final_status; + + auto done_callback = [&counter, &final_status, &status_mutex]( + int id, util::Status status) { + StdMutexLock status_lock(&status_mutex); + final_status.Update(status); + counter.DecrementCount(); + }; + + // Submit asynchronously and wait. + for (auto request : requests) { + RETURN_IF_ERROR(Submit(std::move(request), done_callback)); + } + + counter.Wait(); + return final_status; +} + +util::Status Driver::Cancel(std::shared_ptr request) { + return util::UnimplementedError("Unimplemented."); +} + +util::Status Driver::CancelAllRequests() { + return util::UnimplementedError("Unimplemented."); +} + +util::Status Driver::Close(api::Driver::ClosingMode mode) { + WriterMutexLock state_writer_lock(&state_mutex_); + + if (num_clients_ > 1) { + num_clients_--; + return util::Status(); // OK + } + + if (state_ != kOpen) { + return util::FailedPreconditionError(BadStateMessage(kOpen)); + } + + // Note our intention to close. + RETURN_IF_ERROR(SetState(kClosing)); + + // Before starting shutdown process in the lower layers of the stack, we + // need to cancel all pending requests in the priority queue. + RETURN_IF_ERROR(CancelAllPendingRequests()); + + // If we are not in a rush, just clear the pending requests and let the ones + // that have already started DMAing finish. If ASAP is enabled, we can skip + // this step and a full cleanup of queues happens in DoClose. + if (mode == api::Driver::ClosingMode::kGraceful) { + RETURN_IF_ERROR(DoCancelAndWaitRequests(in_error_)); + } + + // Since chip is getting reset, anything cachedon SRAM will be wiped. + { + StdMutexLock submit_lock(&submit_mutex_); + ResetCachedParameters(); + } + + // Actually close. + RETURN_IF_ERROR(DoClose(in_error_, mode)); + + num_clients_--; + return SetState(kClosed); +} + +void Driver::SetFatalErrorCallback(FatalErrorCallback callback) { + fatal_error_callback_ = std::move(callback); +} + +void Driver::SetThermalWarningCallback(ThermalWarningCallback callback) { + thermal_warning_callback_ = std::move(callback); +} + +void Driver::NotifyFatalError(const util::Status& status) { + // Set error state. + bool was_in_error = std::atomic_exchange(&in_error_, true); + if (!was_in_error) { + // Notify Error only the first time the fatal error is triggered. + // TODO: Issue this is in a new detached thread to decouple + // itself from other driver contexts. + if (fatal_error_callback_) { + fatal_error_callback_(status); + } + } +} + +void Driver::HandleWatchdogTimeout() { + LOG(ERROR) << "Watchdog timed out. 
Collecting runtime metrics."; + auto status_or_request = GetOldestActiveRequest(); + if (!status_or_request.ok()) { + // TODO: Log metric even if TpuRequest is not found. + LOG(ERROR) + << "No active request during watchdog timeout. Unable to log metrics."; + } else { + ExecutionContextInterface* context = status_or_request.ValueOrDie() + ->executable_reference() + .GetPackageReference() + .GetExecutionContextInterface(); + GetTelemeterInterface()->LogWatchdogTimeout(*context); + } + + LOG(ERROR) << "Watchdog activated, resetting TPU."; + CHECK_OK(Close(api::Driver::ClosingMode::kAsap)); + CHECK_OK(Open(debug_mode_)); +} + +util::Status Driver::SetExecutableTiming( + const api::PackageReference* executable, const api::Timing& timing) { + return DoSetExecutableTiming( + static_cast(executable) + ->MainExecutableReference(), + timing); +} + +void Driver::UpdateOperationalSettings(const OperationalSettings& settings) { + StdMutexLock lock(&submit_mutex_); + operational_settings_ = settings; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver.h b/driver/driver.h new file mode 100644 index 0000000..8dccb15 --- /dev/null +++ b/driver/driver.h @@ -0,0 +1,391 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DRIVER_H_ +#define DARWINN_DRIVER_DRIVER_H_ + +#include +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "api/buffer.h" +#include "api/chip.h" +#include "api/driver.h" +#include "api/package_reference.h" +#include "api/request.h" +#include "api/telemeter_interface.h" +#include "driver/default_telemeter.h" +#include "driver/device_buffer_mapper.h" +#include "driver/memory/dma_direction.h" +#include "driver/package_registry.h" +#include "driver/request.h" +#include "driver/time_stamper/time_stamper.h" +#include "driver/tpu_request.h" +#include "executable/executable_generated.h" +#include "port/integral_types.h" +#include "port/shared_mutex.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Base driver implementation. 
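Before the class declaration itself, a short sketch of how client code is expected to drive this interface may help; it uses only methods declared in this header, the package path is a placeholder, and error handling is collapsed into CHECK_OK:

```c++
// Illustrative usage sketch; not part of the runtime.
void RunOnceSketch(platforms::darwinn::api::Driver* driver) {
  CHECK_OK(driver->Open());
  // Hypothetical path; RegisterExecutableFile returns an api::PackageReference*.
  auto package =
      driver->RegisterExecutableFile("/path/to/model.bin").ValueOrDie();
  auto request = driver->CreateRequest(package).ValueOrDie();
  // Input/output buffers are attached to the request before submission
  // (see DriverHelper::Submit later in this patch for a complete example).
  CHECK_OK(driver->Submit(request, [](int id, const util::Status& status) {
    LOG(INFO) << "Request " << id << " done: " << status.ToString();
  }));
  CHECK_OK(driver->Close(
      platforms::darwinn::api::Driver::ClosingMode::kGraceful));
}
```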
+class Driver : public api::Driver { + public: + ~Driver() override; + + bool IsOpen() const override; + + bool IsError() const override; + + util::Status Open(bool debug_mode = false, bool context_lost = false) + LOCKS_EXCLUDED(state_mutex_) override; + + util::StatusOr RegisterExecutableFile( + const std::string& executable_filename) override; + + util::StatusOr RegisterExecutableSerialized( + const std::string& executable_content) override; + + util::StatusOr RegisterExecutableSerialized( + const char* executable_content, size_t length) override; + + util::Status UnregisterExecutable(const api::PackageReference* executable_ref) + LOCKS_EXCLUDED(state_mutex_) override; + + util::StatusOr> CreateRequest( + const api::PackageReference*) override; + + // TODO If we end up spliting driver::Driver to 2 layers, this + // method can go up a layer. + util::Status Submit(std::shared_ptr request, + api::Request::Done done_callback) + LOCKS_EXCLUDED(state_mutex_, submit_mutex_) override; + + util::Status Execute(std::shared_ptr request) + LOCKS_EXCLUDED(state_mutex_, submit_mutex_) override; + + util::Status Execute( + const std::vector>& requests) + LOCKS_EXCLUDED(state_mutex_, submit_mutex_) override; + + util::Status Cancel(std::shared_ptr request) + LOCKS_EXCLUDED(state_mutex_) override; + + util::Status CancelAllRequests() LOCKS_EXCLUDED(state_mutex_) override; + + util::Status Close(api::Driver::ClosingMode mode) + LOCKS_EXCLUDED(state_mutex_) override; + + void SetFatalErrorCallback(FatalErrorCallback callback) override; + + void SetThermalWarningCallback(ThermalWarningCallback callback) override; + + Buffer MakeBuffer(size_t size_bytes) const override; + + util::Status SetRealtimeMode(bool on) override { + return DoSetRealtimeMode(on); + } + + util::Status SetExecutableTiming(const api::PackageReference* executable, + const api::Timing& timing) override; + + util::Status RemoveExecutableTiming(const api::PackageReference* executable) { + return DoRemoveExecutableTiming( + static_cast(executable) + ->MainExecutableReference()); + } + + util::Status SetExecutionPreference(const api::PackageReference* package, + ExecutionPreference preference) override { + return util::OkStatus(); + } + + void SetTelemeterInterface( + api::TelemeterInterface* telemeter_interface) override { + telemeter_interface_ = telemeter_interface; + }; + + void UpdateOperationalSettings(const OperationalSettings& settings) + LOCKS_EXCLUDED(submit_mutex_) override; + + protected: + Driver(api::Chip chip, std::unique_ptr executable_registry, + const api::DriverOptions& driver_options, + std::unique_ptr timestamper); + + // The base driver implementation does the necessary state checks and + // validations before issuing the following calls that are implemented by the + // derived class. + + virtual util::Status DoOpen(bool debug_mode) + EXCLUSIVE_LOCKS_REQUIRED(state_mutex_) = 0; + + virtual util::Status DoClose(bool in_error, api::Driver::ClosingMode mode) + EXCLUSIVE_LOCKS_REQUIRED(state_mutex_) = 0; + + // Cancels pending requests and waits for active requests to finish. 
+ virtual util::Status DoCancelAndWaitRequests(bool in_error) + SHARED_LOCKS_REQUIRED(state_mutex_) = 0; + + virtual util::StatusOr DoMapBuffer(const Buffer& buffer, + DmaDirection direction) + SHARED_LOCKS_REQUIRED(state_mutex_) = 0; + + virtual util::StatusOr> DoCreateRequest( + const std::shared_ptr parent_request, + const ExecutableReference* executable, TpuRequest::RequestType type) + SHARED_LOCKS_REQUIRED(state_mutex_) = 0; + + virtual util::Status DoSetExecutableTiming( + const ExecutableReference* executable, const api::Timing& timing) = 0; + + virtual util::Status DoRemoveExecutableTiming( + const ExecutableReference* executable) { + return util::FailedPreconditionError("Unsupported operation"); + } + + // TODO by just using RT scheduler everywhere, we can avoid the + // complexity of having a capability query here. + virtual bool HasImplementedRealtimeMode() const { return false; } + + virtual util::Status DoSetRealtimeMode(bool on) = 0; + + virtual util::Status DoSubmit(std::shared_ptr request) + + SHARED_LOCKS_REQUIRED(state_mutex_) = 0; + + virtual Buffer DoMakeBuffer(size_t size_bytes) const = 0; + + // Returns the upper bound estimation of driver on the number of cycles of + // work remaining on the device. + virtual int64 MaxRemainingCycles() const = 0; + + // Notifies that the driver / device has entered an error state. + void NotifyFatalError(const util::Status& status); + + // Unregisters all the currently registered models. + util::Status UnregisterAll() { return executable_registry_->UnregisterAll(); } + + // Unmaps all mapped parameters. This method typically needs to get called + // before closing the MMU mapper. + util::Status UnmapAllParameters() { + return executable_registry_->UnmapAllParameters(); + } + + // Handler for when TPU watchdog expires. This signals an unexpected state in + // TPU. + void HandleWatchdogTimeout(); + + // Gets called when a single TpuRequest has finished execution on the device. + // This needs to be called in all sub-classes of Driver. It should be called + // after MaxRemainingCycles is updated. + void HandleTpuRequestCompletion(); + + // Get the telemeter interface pointer. + api::TelemeterInterface* GetTelemeterInterface() { + return telemeter_interface_; + } + + // Returns the oldest submitted request that's still active. + virtual util::StatusOr> + GetOldestActiveRequest() const = 0; + + private: + // Driver state. Transitions : + // kClosed -> kOpen -> kClosing -> kClosed. + enum State { + kOpen, // Driver is Open. + kClosing, // Driver is Closing. + kClosed, // Driver is Closed. (Initial state.) + }; + + // Attempts a state transition to the given state. + util::Status SetState(State next_state) + EXCLUSIVE_LOCKS_REQUIRED(state_mutex_); + +// Generate string to display for bad driver state errors. + std::string BadStateMessage(State expected_state) const + SHARED_LOCKS_REQUIRED(state_mutex_); + + // Internal helper for mapping parameters. + util::Status MapParameters(PackageReference& package_ref) + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Prepares and submits a single inference TpuRequest from the provided + // request. It returns an error if there are no remaining TpuRequests to be + // submitted. + util::Status SubmitInferenceRequest(std::shared_ptr request) + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Reset the state of cached parameters. 
This does not do anything to TPU + // memory, only invalidates the cache state in driver which then results in + // reloading parameters if needed. + void ResetCachedParameters() SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Checks if we need to load to-be-cached parameters to the TPU. + util::StatusOr NeedsParameterCaching( + const std::shared_ptr& request) const + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Submits a parameter caching request and updates the records. + util::Status SubmitParameterCachingRequest( + const std::shared_ptr& request) + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Schedules pending requests (if any) up to the limit we are allowed to have + // tasks pending in the DMA scheduler. It returns OK status if there are no + // more requests to be scheduled. It returns an error if there are any errors + // in submitting requests. + util::Status TrySchedulePendingRequests() SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // If a request is for a package with specified latency tolerance, it returns + // a deadline_exceeded error if driver cannot guarantee that it finishes the + // request in less than the tolerable latency. + util::Status CheckLatencyTolerance(const std::shared_ptr& request) + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Cleans up the priority queues by cancelling all pending requests. + util::Status CancelAllPendingRequests() EXCLUSIVE_LOCKS_REQUIRED(state_mutex_) + LOCKS_EXCLUDED(submit_mutex_); + + // Returns true if we can schedule one more inference for the provided request + // given the current state of DMA scheduler, how long it takes for this + // request on TPU and what our threshold for keeping the pipeline busy is. + // This function should not be called for P0 requests. It always returns true + // If there is no more work in DMA scheduler. + util::StatusOr CanScheduleTpuRequest( + const std::shared_ptr& request) + SHARED_LOCKS_REQUIRED(state_mutex_) + EXCLUSIVE_LOCKS_REQUIRED(submit_mutex_); + + // Updates scheduler with static timing estimation from registered executable. + util::Status UpdateInitialTiming( + const api::PackageReference* api_package_reference) + LOCKS_EXCLUDED(submit_mutex_); + + // Runs the scheduler thread. + void SchedulerWorker(); + + // Maintains integrity of the driver state. + mutable SharedMutex state_mutex_; + + // Guarantees that multiple requests will be submitted in the order provided. + // NOTE: state_mutex_ cannot be acquired after submit_mutex_ is locked. + mutable std::mutex submit_mutex_; + + // Counts the number of clients that opened this driver. + int num_clients_ GUARDED_BY(state_mutex_){0}; + + // Driver state. + State state_ GUARDED_BY(state_mutex_){kClosed}; + + // Chip that this driver controls. + const api::Chip chip_; + + // Executable registry. Null, when device is in closed state. + std::unique_ptr executable_registry_; + + // Driver clock for timestamp reporting + std::unique_ptr time_stamper_; + + // Registered fatal Error Callback. + FatalErrorCallback fatal_error_callback_; + + // Registered thermal warning Callback. + ThermalWarningCallback thermal_warning_callback_; + + // True, if device is in error state. + std::atomic in_error_{false}; + + // The currently active parameter-caching token. This token determines if a + // new submission will require reloading cached parameters in TPU SRAM. 
+ uint64 current_parameter_caching_token_ GUARDED_BY(submit_mutex_); + + // A set of parameter-caching ExecutableReferences that shows if that model + // has already cached its parameters on TPU SRAM, and the cache is still + // valid. + std::unordered_set currently_cached_refs_ + GUARDED_BY(submit_mutex_); + + // Specifies if the driver is currently open in debug mode. + bool debug_mode_; + + // A simple ID generator for requests. + std::atomic next_id_{0}; + + // Current operational settings of the driver. Protected by submit_mutex to + // avoid undefined behavior when it changes while an inference is being + // submitted. + OperationalSettings operational_settings_ GUARDED_BY(submit_mutex_); + + // The maximum amount of work (in terms of nanoseconds spent on TPU) that can + // be scheduled in the DMA scheduler at any given point in time. -1 means no + // maximum and all tasks get scheduled immediately. Exceptions are: + // 1. P0 requests. + // 2. When a single inference takes longer than this time and there is no + // other task scheduled (avoid starvation). + const double max_scheduled_work_ns_; + + // The default telemeter implementation (all logging are NOPs). This is used + // by default if no telemeter interface is set via SetTelemeterInterface. + DefaultTelemeter default_telemeter_; + + // The interface to log telemetry. This object is owned by the caller. + // telemeter_interface_ is initialized to default_telemeter_ in the + // constructor, and can be set to the suitable telemter implementation via + // SetTelemeterInterface(). + api::TelemeterInterface* telemeter_interface_; + + // A map of priority to queue of requests waiting to get scheduled. Priorities + // are always 0 or larger and the larger the number the lower the priority. + std::map>> pending_requests_; + + // The thread that runs scheduler for pending requests. + std::thread scheduler_thread_; + + // Mutex to protect scheduler state. + std::mutex scheduler_mutex_; + + // Condition variable to wake up the scheduler for doing more work or quitting + // at destruction time. + std::condition_variable scheduler_wakeup_; + + // If we want the scheduler to check and submit more of the pending requests ( + // if scheduling constraints are met of course). + bool schedule_more_requests_ GUARDED_BY(scheduler_mutex_){false}; + + // If we are destructing the class. This is used for the scheduler thread to + // know when to quit. + bool destructing_ GUARDED_BY(scheduler_mutex_){false}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DRIVER_H_ diff --git a/driver/driver_factory.cc b/driver/driver_factory.cc new file mode 100644 index 0000000..a99ebb0 --- /dev/null +++ b/driver/driver_factory.cc @@ -0,0 +1,140 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
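The header above leaves all chip- and transport-specific work to the protected Do* hooks. A concrete driver therefore looks roughly like the following abbreviated sketch; the class name is invented, most pure-virtual overrides and the base-class constructor call are omitted, so this is structural only and not compilable as-is:

```c++
// Structural sketch of a derived driver (hypothetical name, most hooks elided).
class UsbDriverSketch : public platforms::darwinn::driver::Driver {
 protected:
  util::Status DoOpen(bool debug_mode) override { return util::OkStatus(); }
  util::Status DoClose(bool in_error, api::Driver::ClosingMode mode) override {
    return util::OkStatus();
  }
  util::Status DoSubmit(std::shared_ptr<TpuRequest> request) override {
    // Hand the request to the DMA scheduler; the completion path later calls
    // HandleTpuRequestCompletion() so the base class can schedule pending work.
    return util::OkStatus();
  }
  int64 MaxRemainingCycles() const override { return 0; }
  // ... DoCancelAndWaitRequests, DoMapBuffer, DoCreateRequest, DoMakeBuffer,
  // DoSetRealtimeMode, DoSetExecutableTiming, GetOldestActiveRequest, etc.
};
```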
+ +#include "driver/driver_factory.h" + +#include +#include + +#include "api/driver_options_generated.h" +#include "api/driver_options_helper.h" +#include "driver/config/chip_config.h" +#include "driver/driver.h" +#include "port/defs.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +std::vector DriverProvider::EnumerateSysfs( + const std::string& class_name, api::Chip chip, api::Device::Type type) { + // Look through sysfs for devices of the given class. + // Sysfs paths look like this: /sys/class//_ + // For example, the first beagle device is: /sys/class/apex/apex_0 + // The corresponding device file is assumed to be: /dev/_ + // For example, if /sys/class/apex/apex_0 exists, we look for /dev/apex_0. + // This convention is used for DarwiNN 1.0 devices. + return EnumerateByClass(class_name, class_name, chip, type); +} + +std::vector DriverFactory::Enumerate() { + StdMutexLock lock(&mutex_); + + std::vector device_list; + for (auto& provider : providers_) { + auto provider_supported_devices = provider->Enumerate(); + for (auto& device : provider_supported_devices) { + device_list.push_back(device); + } + } + + return device_list; +} + +util::StatusOr> DriverFactory::CreateDriver( + const api::Device& device) { + return CreateDriver(device, api::DriverOptionsHelper::Defaults()); +} + +util::StatusOr> DriverFactory::CreateDriver( + const api::Device& device, const api::Driver::Options& opaque_options) { + StdMutexLock lock(&mutex_); + + // Deserialize options. + const api::DriverOptions* options = + api::GetDriverOptions(opaque_options.data()); + if (options == nullptr) { + return util::InvalidArgumentError("Invalid Driver::Options instance."); + } + + if (options->version() != Driver::kOptionsVersion) { + return util::InvalidArgumentError("Invalid Driver::Options version."); + } + + // Update verbosity level. + // TODO: Verbosity level should be of per driver instance. +#if !DARWINN_PORT_USE_GOOGLE3 + if (options->verbosity() >= 0) { + ::platforms::darwinn::internal::SetLoggingLevel(options->verbosity()); + } +#endif // !DARWINN_PORT_USE_GOOGLE3 + + for (auto& provider : providers_) { + // Skip if the provider cannot create driver for this device spec. + if (!provider->CanCreate(device)) { + continue; + } + + // Always invoke only the first provider which claims the ability. + if (device.path != kDefaultDevicePath) { + return provider->CreateDriver(device, *options); + } + + // Try to enumerate with this provider. + std::vector device_list = provider->Enumerate(); + // Skip this provider if there is no any device. + if (device_list.empty()) { + continue; + } + + // Create driver associated with the device in the resulting list. 
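As an aside, the enumeration and creation paths in this factory compose in application code roughly as follows; this is a sketch that uses only functions declared in this patch (the single-argument CreateDriver overload applies api::DriverOptionsHelper::Defaults()):

```c++
// Illustrative sketch of factory usage; not part of the runtime.
util::Status OpenFirstDeviceSketch() {
  auto* factory = platforms::darwinn::api::DriverFactory::GetOrCreate();
  auto devices = factory->Enumerate();
  if (devices.empty()) {
    return util::NotFoundError("No Edge TPU devices found.");
  }
  ASSIGN_OR_RETURN(auto driver, factory->CreateDriver(devices[0]));
  RETURN_IF_ERROR(driver->Open());
  // ... register packages and submit requests ...
  return driver->Close(platforms::darwinn::api::Driver::ClosingMode::kGraceful);
}
```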
+ for (const auto& provider_device : device_list) { + if (device.chip == provider_device.chip && + device.type == provider_device.type) { + return provider->CreateDriver(provider_device, *options); + } + } + } + + return util::NotFoundError("Unable to construct driver for device."); +} + +void DriverFactory::RegisterDriverProvider( + std::unique_ptr provider) { + StdMutexLock lock(&mutex_); + providers_.push_back(std::move(provider)); +} + +DriverFactory* DriverFactory::GetOrCreate() { + static std::unique_ptr singleton = + gtl::WrapUnique(new driver::DriverFactory()); + return singleton.get(); +} + +} // namespace driver + +namespace api { + +DriverFactory* DriverFactory::GetOrCreate() { + return driver::DriverFactory::GetOrCreate(); +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver_factory.h b/driver/driver_factory.h new file mode 100644 index 0000000..864f4e5 --- /dev/null +++ b/driver/driver_factory.h @@ -0,0 +1,182 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DRIVER_FACTORY_H_ +#define DARWINN_DRIVER_DRIVER_FACTORY_H_ + +#include +#include // NOLINT +#include +#include +#include + +#include "api/chip.h" +#include "api/driver.h" +#include "api/driver_factory.h" +#include "api/driver_options_generated.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interface for a class that can provide a Driver implementation. +// +// Once implemented, driver providers needs to be registered with the +// DriverFactory using the following function in driver_factory.h +// +// REGISTER_DRIVER_PROVIDER(..). +// +// The subclasses of DriverProvider must implement a static CreateDriverProvider +// function with following signature. The driver provider cannot be registered +// without this function. +// +// static std::unique_ptr CreateDriverProvider(); +class DriverProvider { + public: + DriverProvider() = default; + virtual ~DriverProvider() = default; + + // This class is neither copyable nor movable. + DriverProvider(const DriverProvider&) = delete; + DriverProvider& operator=(const DriverProvider&) = delete; + + // Enumerates all devices available through this provider. + virtual std::vector Enumerate() = 0; + + // Returns true, if the factory can create driver for given device. + virtual bool CanCreate(const api::Device& device) = 0; + + // Returns a driver instance that interfaces with specified device. + // Custom options specified here would override default ones. The exact set of + // possible key-value pairs is provider-specific. + virtual util::StatusOr> CreateDriver( + const api::Device& device, const api::DriverOptions& options) = 0; + + protected: + // Helper function that looks for devices by iterating over directory entries + // in /sys/class//* and matching them against files + // in /dev. 
+ std::vector EnumerateSysfs(const std::string& class_name, + api::Chip chip, + api::Device::Type type); + // Same as above but specifying class name and device name separately: + // /sys/class//* + std::vector EnumerateByClass(const std::string& class_name, + const std::string& device_name, + api::Chip chip, + api::Device::Type type); +}; + +// Enumerates devices and creates drivers for those devices. +class DriverFactory : public api::DriverFactory { + public: + // Creates or returns the singleton instance of the driver factory. + static DriverFactory* GetOrCreate(); + + // This class is neither copyable nor movable. + DriverFactory(const DriverFactory&) = delete; + DriverFactory& operator=(const DriverFactory&) = delete; + + ~DriverFactory() = default; + + // Enumerates all available devices. + std::vector Enumerate() override LOCKS_EXCLUDED(mutex_); + + // Creates a driver instance that interfaces to the specified device. + util::StatusOr> CreateDriver( + const api::Device& device) override LOCKS_EXCLUDED(mutex_); + + // Creates a driver instance that interfaces to the specified device with + // custom options. + util::StatusOr> CreateDriver( + const api::Device& device, const api::Driver::Options& options) override + LOCKS_EXCLUDED(mutex_); + + // Registers a new driver provider. + void RegisterDriverProvider(std::unique_ptr provider) + LOCKS_EXCLUDED(mutex_); + + private: + // Constructor. + DriverFactory() = default; + + // Container for all registered driver providers. + std::vector> providers_ GUARDED_BY(mutex_); + + // Maintains integrity of providers_. + mutable std::mutex mutex_; +}; + +namespace internal { + +// Functions for checking that the DriverProvider has the required +// CreateDriverProvider function. +template +constexpr bool DriverProviderHasCreateDriverProvider() { + typedef std::unique_ptr (*CreateDriverProviderType)(); + return std::is_same::value; +} + +// Provides access to the static functions within a specific subclass +// of DriverProvider. +template +class StaticAccessToDriverProvider { + public: + static_assert(std::is_base_of<::platforms::darwinn::driver::DriverProvider, + DriverProviderSubclass>::value, + "Classes registered with REGISTER_DRIVER_PROVIDER must be " + "subclasses of ::platforms::darwinn::driver::DriverProvider."); + + static_assert( + DriverProviderHasCreateDriverProvider(), + "CreateDriverProvider() must be defined with the correct signature " + "in every DriverProvider."); + + // Provides access to the static function CreateDriverProvider within a + // specific subclass of DriverProvider. + static std::unique_ptr CreateDriverProvider() { + // DriverProviderSubclass must implement this function, since it is not + // implemented in the parent class. + return DriverProviderSubclass::CreateDriverProvider(); + } +}; + +// Registrar that registers an instance of DriverProviderSubclass during +// construction. +template +class DriverProviderRegistrar { + public: + DriverProviderRegistrar() { + auto provider = StaticAccessToDriverProvider< + DriverProviderSubclass>::CreateDriverProvider(); + DriverFactory::GetOrCreate()->RegisterDriverProvider(std::move(provider)); + } +}; + +} // namespace internal + +// Macro for registering DriverProviders. 
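Tying the pieces together, a provider implementation defines the static CreateDriverProvider factory and then invokes the registration macro defined just below. A hypothetical example, assumed to live inside namespace platforms::darwinn::driver, with enumeration and creation stubbed out:

```c++
// Hypothetical provider sketch; real providers enumerate actual devices and
// construct a concrete driver::Driver subclass.
class MyChipDriverProvider : public DriverProvider {
 public:
  static std::unique_ptr<DriverProvider> CreateDriverProvider() {
    return gtl::WrapUnique(new MyChipDriverProvider());
  }

  std::vector<api::Device> Enumerate() override {
    // e.g. return EnumerateSysfs("apex", ...) on Linux.
    return {};
  }

  bool CanCreate(const api::Device& device) override { return false; }

  util::StatusOr<std::unique_ptr<api::Driver>> CreateDriver(
      const api::Device& device, const api::DriverOptions& options) override {
    return util::UnimplementedError("Sketch only.");
  }
};

REGISTER_DRIVER_PROVIDER(MyChipDriverProvider);
```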
+#define REGISTER_DRIVER_PROVIDER(name) \ + static ::platforms::darwinn::driver::internal::DriverProviderRegistrar \ + DriverProviderRegistrar##name + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DRIVER_FACTORY_H_ diff --git a/driver/driver_factory_darwin.cc b/driver/driver_factory_darwin.cc new file mode 100644 index 0000000..291ac42 --- /dev/null +++ b/driver/driver_factory_darwin.cc @@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/driver_factory.h" + +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +std::vector DriverProvider::EnumerateByClass( + const std::string& class_name, const std::string& device_name, + api::Chip chip, api::Device::Type type) { + std::vector device_list; + LOG(FATAL) << "EnumerateByClass is not supported on macOS at this time."; + return device_list; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver_factory_default.cc b/driver/driver_factory_default.cc new file mode 100644 index 0000000..9921da7 --- /dev/null +++ b/driver/driver_factory_default.cc @@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/driver_factory.h" + +#include +#include + +namespace platforms { +namespace darwinn { +namespace driver { + +std::vector DriverProvider::EnumerateByClass( + const std::string& class_name, const std::string& device_name, + api::Chip chip, api::Device::Type type) { + std::vector device_list; + const std::string class_dir_name = "/sys/class/" + class_name; + DIR* dir = opendir(class_dir_name.c_str()); + if (dir == nullptr) { + VLOG(2) << "Failed to open " << class_dir_name << ": " << strerror(errno); + return device_list; // empty list + } + + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + std::string entry_name(entry->d_name); + if (entry_name == "." 
|| entry_name == "..") { + continue; + } + if (entry_name.compare(0, device_name.size(), device_name) != 0) { + continue; + } + const std::string dev_file_name = "/dev/" + entry_name; + struct stat statbuf; + int ret = stat(dev_file_name.c_str(), &statbuf); + if (ret != 0) { + VLOG(1) << "Failed to stat " << dev_file_name << ": " << strerror(errno); + continue; + } + if (!S_ISCHR(statbuf.st_mode)) { + LOG(ERROR) << dev_file_name << " is not a character device."; + continue; + } + device_list.push_back({chip, type, dev_file_name}); + } + + closedir(dir); + return device_list; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver_factory_windows.cc b/driver/driver_factory_windows.cc new file mode 100644 index 0000000..9f63088 --- /dev/null +++ b/driver/driver_factory_windows.cc @@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/driver_factory.h" + +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +std::vector DriverProvider::EnumerateByClass( + const std::string& class_name, const std::string& device_name, + api::Chip chip, api::Device::Type type) { + std::vector device_list; + LOG(FATAL) << "EnumerateByClass is not supported on Windows at this time."; + return device_list; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver_helper.cc b/driver/driver_helper.cc new file mode 100644 index 0000000..8d8a409 --- /dev/null +++ b/driver/driver_helper.cc @@ -0,0 +1,654 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/driver_helper.h" + +#include + +#include +#include // NOLINT +#include // NOLINT +#include +#include +#include +#include +#include + +#include "api/buffer.h" +#include "api/driver.h" +#include "api/package_reference.h" +#include "api/request.h" +#include "driver/executable_util.h" +#include "driver/package_registry.h" +#include "executable/executable_generated.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { +// Pattern to be filled into guard areas, and output data buffers. 
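The anonymous-namespace helpers that follow implement a simple canary scheme around output tensors. The per-output allocation built later in DriverHelper::Submit can be pictured as follows (region sizes are whatever guard_area_size_bytes_ and the layer size dictate):

```c++
// Conceptual layout of one output allocation when guard areas are enabled:
//
//   [ leading guard | output tensor | trailing guard ]
//
// Both guard regions are pre-filled with the repeating GuardPattern defined
// below and must be bit-for-bit intact after the request completes
// (CheckIfAreaIsIntact). If prefill_output_tensors_ is set, the output region
// is also pre-filled; after completion it must no longer contain a run of the
// pattern longer than kMaxConsecutiveMatch bytes
// (CheckIfAreaIsCompletelyOverwritten), which would indicate bytes the TPU
// never wrote.
```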
+constexpr std::array GuardPattern = {0xDE, 0xAD, 0xBE, 0xEF}; + +// Max consecutive matches to count a part of output buffer as not overwritten. +// The shorter, the easier for a false negative (i.e. falsely claiming an error +// has occurred). The longer, the easier for a false positive (i.e. falsely +// claiming no error exists). +constexpr int kMaxConsecutiveMatch = 8; + +template +void FillAreaWithKnownPattern(const Buffer& guard_area, + const T& guard_pattern) { + auto pc = const_cast(guard_area.ptr()); + auto end = guard_area.ptr() + guard_area.size_bytes(); + for (int i = 0; pc != end; ++pc) { + *pc = guard_pattern[i]; + i = (++i) % guard_pattern.size(); + } +} + +template +bool CheckIfAreaIsIntact(const Buffer& guard_area, const T& guard_pattern) { + auto pc = guard_area.ptr(); + auto end = guard_area.ptr() + guard_area.size_bytes(); + for (int i = 0; pc != end; ++pc) { + if (*pc != guard_pattern[i]) { + VLOG(1) << StringPrintf( + "Buffer offset %ld (%p) has been tainted. 0x%X != 0x%X", + pc - guard_area.ptr(), pc, *pc, guard_pattern[i]); + return false; + } + i = (++i) % guard_pattern.size(); + } + + return true; +} + +template +bool CheckIfAreaIsCompletelyOverwritten(const Buffer& output_data, + const T& guard_pattern, + int fail_on_consecutive_match) { + auto pc = output_data.ptr(); + auto end = output_data.ptr() + output_data.size_bytes(); + int count = 0; + for (int i = 0; pc != end; ++pc) { + if (*pc == guard_pattern[i]) { + ++count; + } else { + if (count >= fail_on_consecutive_match) { + break; + } + count = 0; + } + i = (++i) % guard_pattern.size(); + } + + if (count >= fail_on_consecutive_match) { + LOG(WARNING) << StringPrintf( + "Buffer offset %ld (%p) is probably not overwritten by output " + "activations. Running length: %d", + (pc - output_data.ptr()) - count, pc - count, count); + + return false; + } + + return true; +} + +// Converts a buffer to a string. +// Similar to model_compiler_file_util::StoreToString. +std::string ConvertToString(const Buffer::NamedMap& activations) { + std::vector activation_names; + activation_names.reserve(activations.size()); + for (const auto& activation : activations) { + activation_names.push_back(activation.first); + } + // Named activation buffers are sorted by name in output. + std::sort(activation_names.begin(), activation_names.end()); + + std::string output; + for (const auto& name : activation_names) { + const auto& batched_output = activations.at(name); + for (const auto& output_batch : batched_output) { + const auto output_batch_string = + std::string(reinterpret_cast(output_batch.ptr()), + output_batch.size_bytes()); + output.insert(output.end(), output_batch_string.begin(), + output_batch_string.end()); + } + } + + return output; +} + +util::Status WriteToFile(const std::string& output_file_name, + const std::string& output_content) { + std::ofstream record_file(output_file_name, std::ios_base::out); + + if (record_file.is_open()) { + record_file.write(output_content.c_str(), output_content.size()); + record_file.close(); + + if (!record_file) { + return util::InternalError("Failed writing execution record."); + } + } else { + return util::InternalError("Failed opening file for dumping output."); + } + + return util::OkStatus(); +} + +// Returns true if the actual output matches with expected on the count for each +// unique byte value. This is used to provide a hint that a data mismatch is +// probably caused by re-layout issues. 
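A concrete illustration of the byte-histogram comparison described in the comment above and implemented just below:

```c++
// Worked example of the histogram check in MatchesWithoutRelayout:
//   actual   = {0x01, 0x02, 0x03, 0x04}
//   expected = {0x04, 0x03, 0x02, 0x01}  -> same multiset of byte values, so
//                                           it "matches without relayout" and
//                                           the mismatch is likely an
//                                           ordering / re-layout problem.
//   actual   = {0x01, 0x01, 0x02, 0x03}
//   expected = {0x01, 0x02, 0x03, 0x04}  -> byte counts differ, so this is a
//                                           genuine data mismatch.
```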
+bool MatchesWithoutRelayout(const uint8* actual_output, + const uint8* expected_output, size_t size) { + constexpr int kNumPossibleValues = std::numeric_limits::max() + 1; + std::array byte_count_actual_output{}, + byte_count_expected_output{}; + + // Count the number of each byte value in the outputs. + for (size_t i = 0; i < size; i++) { + ++byte_count_actual_output[actual_output[i]]; + ++byte_count_expected_output[expected_output[i]]; + } + + // Make sure they match. + for (int i = 0; i < kNumPossibleValues; ++i) { + if (byte_count_expected_output[i] != byte_count_actual_output[i]) { + return false; + } + } + + return true; +} + +} // namespace + +DriverHelper::DriverHelper(std::unique_ptr driver, + int max_pending_requests, + bool prefill_output_tensors, + size_t guard_area_size_bytes) + : driver_(std::move(driver)), + max_pending_requests_(max_pending_requests), + prefill_output_tensors_(prefill_output_tensors), + guard_area_size_bytes_(guard_area_size_bytes) {} + +bool DriverHelper::IsOpen() const { return driver_->IsOpen(); } + +bool DriverHelper::IsError() const { return driver_->IsError(); } + +util::Status DriverHelper::Cancel(std::shared_ptr request) { + return driver_->Cancel(std::move(request)); +} + +util::Status DriverHelper::CancelAllRequests() { + return driver_->CancelAllRequests(); +} + +uint64_t DriverHelper::allocation_alignment_bytes() const { + return driver_->allocation_alignment_bytes(); +} + +Buffer DriverHelper::MakeBuffer(size_t size_bytes) const { + return driver_->MakeBuffer(size_bytes); +} + +util::Status DriverHelper::Open(bool debug_mode, bool context_lost) { + return driver_->Open(debug_mode, context_lost); +} + +util::StatusOr +DriverHelper::RegisterExecutableFile(const std::string& executable_filename) { + return driver_->RegisterExecutableFile(executable_filename); +} + +util::StatusOr +DriverHelper::RegisterExecutableSerialized( + const std::string& executable_content) { + return driver_->RegisterExecutableSerialized(executable_content); +} + +util::StatusOr +DriverHelper::RegisterExecutableSerialized(const char* executable_content, + size_t length) { + return driver_->RegisterExecutableSerialized(executable_content, length); +} + +util::Status DriverHelper::UnregisterExecutable( + const api::PackageReference* executable_ref) { + return driver_->UnregisterExecutable(executable_ref); +} + +util::StatusOr> DriverHelper::CreateRequest( + const api::PackageReference* executable_ref) { + return driver_->CreateRequest(executable_ref); +} + +util::Status DriverHelper::Execute(std::shared_ptr request) { + return driver_->Execute(request); +} + +util::Status DriverHelper::Execute( + const std::vector>& requests) { + return driver_->Execute(requests); +} + +util::Status DriverHelper::Submit(std::shared_ptr request, + api::Request::Done done_callback) { + // Request completion callback. + // Note that the whole callback functor is cloned into this one, so it's + // available when done. + auto start_time = std::chrono::steady_clock::now(); + auto wrapped_done = [this, done_callback, start_time]( + int id, const util::Status& status) { + auto roundtrip_time_ms = std::chrono::duration( + std::chrono::steady_clock::now() - start_time) + .count(); + VLOG(1) << StringPrintf("Request [%d] complete. Status=%s. 
Took %f ms.", id, + status.ToString().c_str(), roundtrip_time_ms); + StdMutexLock lock(&mutex_); + CHECK_GT(pending_requests_, 0); + --pending_requests_; + roundtrip_times_ms_.push_back(roundtrip_time_ms); + cv_.notify_all(); + + auto verification_start_time = std::chrono::steady_clock::now(); + done_callback(id, status); + auto verification_time_ms = + std::chrono::duration( + std::chrono::steady_clock::now() - verification_start_time) + .count(); + verification_times_ms_.push_back(verification_time_ms); + }; + + VLOG(1) << StringPrintf("Request [%d] submitting.", request->id()); + return driver_->Submit(std::move(request), std::move(wrapped_done)); +} + +util::Status DriverHelper::Submit(const TestVector& test_vector, int batches) { + Buffer::NamedMap input; + Buffer::NamedMap expected_output; + Buffer::NamedMap output; + Buffer::NamedMap output_with_guard_areas; + + const auto* executable_ref = test_vector.executable_reference(); + + if (batches <= 0) { + batches = executable_ref->BatchSize(); + } + + // Compiler dumps input and expected output buffers in alphabetical order. + auto input_names = executable_ref->InputLayerNames(); + auto output_names = executable_ref->OutputLayerNames(); + std::sort(input_names.begin(), input_names.end()); + std::sort(output_names.begin(), output_names.end()); + + // Prepare input buffers. + if (!input_names.empty()) { + const std::string& input_string = test_vector.GetInput(); + int input_base = 0; + for (const auto& input_name : input_names) { + VLOG(5) << StringPrintf("Preparing input buffers for %s.", + input_name.c_str()); + ASSIGN_OR_RETURN(const int input_size, + executable_ref->InputLayerPaddedSizeBytes(input_name)); + auto batch_buffer = driver_->MakeBuffer(input_size * batches); + for (int i = 0; i < batches; ++i) { + // The input file contains a number of input buffers that matches the + // native batch size. Add them to the request in order, but loop back + // to the first buffer again if we go past the end. + const int input_pos = + input_base + (i % executable_ref->BatchSize()) * input_size; + CHECK_LE(input_pos + input_size, input_string.size()); + + Buffer input_buffer = batch_buffer.Slice(i * input_size, input_size); + std::copy(input_string.begin() + input_pos, + input_string.begin() + input_pos + input_size, + input_buffer.ptr()); + + input[input_name].push_back(std::move(input_buffer)); + } + input_base += input_size * executable_ref->BatchSize(); + } + } + + // Prepare output and expected output buffers. + if (!output_names.empty()) { + const std::string& expected_output_string = test_vector.GetExpectedOutput(); + int output_base = 0; + for (const auto& output_name : output_names) { + VLOG(5) << StringPrintf("Preparing output buffers for %s.", + output_name.c_str()); + ASSIGN_OR_RETURN(const int output_size, + executable_ref->OutputLayerSizeBytes(output_name)); + + // Allocate buffer with guard area. + size_t output_plus_guard_area_size = + output_size + (guard_area_size_bytes_ * 2); + auto batch_buffer = + driver_->MakeBuffer(output_plus_guard_area_size * batches); + for (int i = 0; i < batches; ++i) { + // The expected output contains a number of output buffers that matches + // the native batch size (and the input file). Again, we add output + // buffers to the request in order, with wrap-around at the end. + const int output_pos = + output_base + (i % executable_ref->BatchSize()) * output_size; + CHECK_LE(output_pos + output_size, expected_output_string.size()); + + // Prepare expected output buffer. 
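The wrap-around indexing used above for inputs, and repeated below for outputs, is easiest to see with a small example:

```c++
// Worked example of the (i % BatchSize()) wrap-around. Assume the executable's
// native batch size is 4, the layer is 100 bytes per batch element, and the
// caller asked for batches = 6:
//   i = 0..3 -> offset (i % 4) * 100 = 0, 100, 200, 300  (straight from the
//                                                          test-vector file)
//   i = 4    -> offset (4 % 4) * 100 = 0                  (wraps to the start)
//   i = 5    -> offset (5 % 4) * 100 = 100
// Test-vector data is therefore reused cyclically whenever the requested batch
// count exceeds the executable's native batch size.
```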
+ auto expected_output_buffer = driver_->MakeBuffer(output_size); + std::copy(expected_output_string.begin() + output_pos, + expected_output_string.begin() + output_pos + output_size, + expected_output_buffer.ptr()); + + expected_output[output_name].push_back( + std::move(expected_output_buffer)); + + // Generated output buffer. + if (guard_area_size_bytes_ == 0) { + // No guard area. + auto output_buffer = batch_buffer.Slice(i * output_size, output_size); + + if (prefill_output_tensors_) { + FillAreaWithKnownPattern(output_buffer, GuardPattern); + } + + output[output_name].push_back(std::move(output_buffer)); + } else { + // Allocate buffer with guard area. + auto output_buffer_with_guard = batch_buffer.Slice( + i * output_plus_guard_area_size, output_plus_guard_area_size); + + Buffer leading_guard_area(output_buffer_with_guard.ptr(), + guard_area_size_bytes_); + FillAreaWithKnownPattern(leading_guard_area, GuardPattern); + + // memcpy is going to work as well, but having a separate buffer and + // call to fill up is slightly more flexible regarding size and + // pattern. + Buffer trailing_guard_area(output_buffer_with_guard.ptr() + + guard_area_size_bytes_ + output_size, + guard_area_size_bytes_); + FillAreaWithKnownPattern(trailing_guard_area, GuardPattern); + + Buffer output_buffer = + Buffer(output_buffer_with_guard.ptr() + guard_area_size_bytes_, + output_size); + + if (prefill_output_tensors_) { + FillAreaWithKnownPattern(output_buffer, GuardPattern); + } + + output[output_name].push_back(std::move(output_buffer)); + output_with_guard_areas[output_name].push_back( + std::move(output_buffer_with_guard)); + } + } + output_base += output_size * executable_ref->BatchSize(); + } + } + + return Submit(test_vector.name(), test_vector.executable_reference(), + test_vector.output_file_name(), input, expected_output, output, + output_with_guard_areas); +} + +util::Status DriverHelper::Submit( + const std::string& tag, const api::PackageReference* executable_ref, + const Buffer::NamedMap& input, const Buffer::NamedMap& output, + const Buffer::NamedMap& output_with_guard_areas, + api::Request::Done request_done) { + ASSIGN_OR_RETURN(auto request, CreateRequest(executable_ref)); + + // Attach inputs to the request. + for (auto& named_input : input) { + for (auto& input_buffer : named_input.second) { + RETURN_IF_ERROR(request->AddInput(named_input.first, input_buffer)); + } + } + + // Attach outputs to the request. + for (auto& named_output : output) { + for (auto& output_buffer : named_output.second) { + RETURN_IF_ERROR(request->AddOutput(named_output.first, output_buffer)); + } + } + + // Increase pending and total requests before submission, so the completion + // callback can make correct calculations. If batching is enabled, each + // request holds one batch which is multiple inferences. + { + StdCondMutexLock lock(&mutex_); + if (total_requests_ == 0) { + first_submit_ = std::chrono::steady_clock::now(); + } + ++pending_requests_; + ++total_requests_; + } + + // Submit. + VLOG(1) << StringPrintf("Request [%d, %s] submitting.", request->id(), + tag.c_str()); + + auto submit_status = Submit(request, std::move(request_done)); + + { + StdCondMutexLock lock(&mutex_); + + if (!submit_status.ok()) { + // Decrease request counters, as submission has failed. + --pending_requests_; + --total_requests_; + return submit_status; + } else { + // Waits synchronously, if we reach maximum pending requests. 
+ while (pending_requests_ >= max_pending_requests_) { + cv_.wait(lock); + } + } + } + + return util::Status(); // OK. +} + +util::Status DriverHelper::Submit(const std::string& tag, + const api::PackageReference* executable_ref, + const Buffer::NamedMap& input, + const Buffer::NamedMap& expected_output, + const Buffer::NamedMap& output) { + Buffer::NamedMap no_guard_areas; + return Submit(tag, executable_ref, /*output_file_name=*/std::string{}, input, + expected_output, output, no_guard_areas); +} + +util::Status DriverHelper::Submit( + const std::string& tag, const api::PackageReference* executable_ref, + const std::string& output_file_name, const Buffer::NamedMap& input, + const Buffer::NamedMap& expected_output, const Buffer::NamedMap& output, + const Buffer::NamedMap& output_with_guard_areas) { + // Note that all the Buffer::NamedMap instances are cloned into the functor + // when it's created, and hence they can be used to verify correctness of + // result when the functor is actually executed. Also note that the Buffer + // objects used are all "host" buffers with shared_ptr, so a memory block + // would only be released when the last Buffer instance pointing to that + // memory block is destructed. + auto request_done = [this, tag, executable_ref, output, + output_with_guard_areas, expected_output, + output_file_name](int id, const util::Status& status) { + if (!status.ok()) { + LOG(INFO) << StringPrintf("Request [%d, %s] failed: %s", id, tag.c_str(), + status.error_message().c_str()); + return; + } + + // Compare each output buffer. + for (const auto& output_name : executable_ref->OutputLayerNames()) { + for (int i = 0; i < expected_output.at(output_name).size(); ++i) { + const auto& output_buffer = output.at(output_name)[i]; + const auto& expected_output_buffer = expected_output.at(output_name)[i]; + + CHECK_EQ(output_buffer.size_bytes(), + expected_output_buffer.size_bytes()); + + if (prefill_output_tensors_) { + CHECK(CheckIfAreaIsCompletelyOverwritten(output_buffer, GuardPattern, + kMaxConsecutiveMatch)); + } + + if (guard_area_size_bytes_ > 0) { + CHECK(!output_with_guard_areas.empty()); + for (auto& named_output : output) { + auto it_with_guard_areas = + output_with_guard_areas.find(named_output.first); + if (it_with_guard_areas == output_with_guard_areas.end()) { + LOG(FATAL) << "Cannot find output [" << named_output.first + << "] in guard area info"; + } + + const std::vector& device_outputs = named_output.second; + const std::vector& device_outputs_with_guard_areas = + it_with_guard_areas->second; + + CHECK_EQ(device_outputs.size(), + device_outputs_with_guard_areas.size()); + + for (size_t i = 0; i < device_outputs.size(); ++i) { + // Check the leading guard area is not touched. + Buffer leading_guard_area( + device_outputs_with_guard_areas[i].ptr(), + guard_area_size_bytes_); + CHECK(CheckIfAreaIsIntact(leading_guard_area, GuardPattern)) + << "Output [" << named_output.first << "][" << i + << "]. Leading guard area has been tainted"; + + // Check the trailing guard area is not touched. + // memcmp is going to work as well, but having a separate buffer + // and call to verify is slightly more flexible regarding size + // and pattern. + Buffer trailing_guard_area( + device_outputs_with_guard_areas[i].ptr() + + guard_area_size_bytes_ + device_outputs[i].size_bytes(), + guard_area_size_bytes_); + CHECK(CheckIfAreaIsIntact(trailing_guard_area, GuardPattern)) + << "Output [" << named_output.first << "][" << i + << "]. 
Trailing guard area has been tainted"; + } + } + } + + if (memcmp(output_buffer.ptr(), expected_output_buffer.ptr(), + expected_output_buffer.size_bytes()) != 0) { + if (MatchesWithoutRelayout(output_buffer.ptr(), + expected_output_buffer.ptr(), + expected_output_buffer.size_bytes())) { + LOG(ERROR) << StringPrintf( + "Mismatched result, but every unique byte value has the same " + "number of elements in both data sets. " + "This is probably an error related to re-layout\n"); + } + + for (int element = 0; element < expected_output_buffer.size_bytes(); + ++element) { + if (output_buffer.ptr()[element] != + expected_output_buffer.ptr()[element]) { + if (!output_file_name.empty()) { + CHECK_OK( + WriteToFile(output_file_name, ConvertToString(output))); + } + + LOG(FATAL) << StringPrintf( + "Mismatched result: output_name = %s, batch = %d, " + "size_bytes = %zd.\nFirst mismatched element at %d: %x vs " + "%x", + output_name.c_str(), i, expected_output_buffer.size_bytes(), + element, output_buffer.ptr()[element], + expected_output_buffer.ptr()[element]); + } + } + } + } + } + LOG(INFO) << StringPrintf("Request [%d, %s] verified.", id, tag.c_str()); + }; + + return Submit(tag, executable_ref, input, output, output_with_guard_areas, + std::move(request_done)); +} + +util::Status DriverHelper::Close(api::Driver::ClosingMode mode) { + StdCondMutexLock lock(&mutex_); + while (pending_requests_ > 0) { + VLOG(5) << StringPrintf("Waiting for %d pending requests.", + pending_requests_); + cv_.wait(lock); + } + + auto last_submit_complete = std::chrono::steady_clock::now(); + auto diff_millis = std::chrono::duration( + last_submit_complete - first_submit_) + .count(); + LOG(INFO) << StringPrintf( + "%d requests processed in %.3f ms at a rate of %.3f requests per " + "second or %.3f ms per request.", + total_requests_, diff_millis, total_requests_ * 1000.0 / diff_millis, + diff_millis / total_requests_); + auto sum_verification_time_ms = std::accumulate( + verification_times_ms_.begin(), verification_times_ms_.end(), 0.0); + diff_millis -= sum_verification_time_ms; + LOG(INFO) << StringPrintf( + "Total process time excluding verification is %.3f ms at a rate of " + "%.3f requests per second or %.3f ms per request.", + diff_millis, total_requests_ * 1000.0 / diff_millis, + diff_millis / total_requests_); + LOG(INFO) << StringPrintf( + "Average inference time (As observed by each request which grows with " + "the number of pending_requests) : %.3f ms.", + std::accumulate(roundtrip_times_ms_.begin(), roundtrip_times_ms_.end(), + 0.0) / + roundtrip_times_ms_.size()); + + return driver_->Close(mode); +} + +void DriverHelper::SetFatalErrorCallback(FatalErrorCallback callback) { + driver_->SetFatalErrorCallback(std::move(callback)); +} + +void DriverHelper::SetThermalWarningCallback(ThermalWarningCallback callback) { + driver_->SetThermalWarningCallback(std::move(callback)); +} + +util::Status DriverHelper::SetRealtimeMode(bool on) { + return util::FailedPreconditionError( + "This driver does not support real-time mode."); +} + +util::Status DriverHelper::SetExecutableTiming( + const api::PackageReference* executable, const api::Timing& timing) { + return util::FailedPreconditionError( + "This driver does not support real-time mode."); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/driver_helper.h b/driver/driver_helper.h new file mode 100644 index 0000000..a9428f7 --- /dev/null +++ b/driver/driver_helper.h @@ -0,0 +1,208 @@ +// Copyright 2019 Google LLC +// +// 
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_DRIVER_HELPER_H_ +#define DARWINN_DRIVER_DRIVER_HELPER_H_ + +#include + +#include // NOLINT +#include // NOLINT +#include +#include +#include +#include + +#include "api/buffer.h" +#include "api/chip.h" +#include "api/driver.h" +#include "api/package_reference.h" +#include "api/request.h" +#include "api/telemeter_interface.h" +#include "api/timing.h" +#include "driver/executable_util.h" +#include "driver/package_registry.h" +#include "driver/test_vector.h" +#include "executable/executable_generated.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Wraps a driver instance with additional functions that performs tests and +// verify results. +class DriverHelper : public api::Driver { + public: + DriverHelper(std::unique_ptr driver, int max_pending_requests, + bool prefill_output_tensors = false, + size_t guard_area_size_bytes = 0); + + // Destructor. Waits for pending tasks to avoid Submit callbacks + // acquiring otherwise-destructed mutex_. See b/111616778. + ~DriverHelper() override { + if (IsOpen()) CHECK_OK(Close(api::Driver::ClosingMode::kGraceful)); + } + + util::Status Open(bool debug_mode, bool context_lost = false) final + LOCKS_EXCLUDED(mutex_); + + util::Status Close(api::Driver::ClosingMode mode) final + LOCKS_EXCLUDED(mutex_); + + bool IsOpen() const final LOCKS_EXCLUDED(mutex_); + + bool IsError() const final; + + util::StatusOr RegisterExecutableFile( + const std::string& executable_filename) final; + + util::StatusOr RegisterExecutableSerialized( + const std::string& executable_content) final; + util::StatusOr RegisterExecutableSerialized( + const char* executable_content, size_t length) final; + + util::Status UnregisterExecutable( + const api::PackageReference* executable_ref) final; + + util::StatusOr> CreateRequest( + const api::PackageReference* executable_ref) final; + + util::Status Submit(std::shared_ptr request, + api::Request::Done done_callback) final; + + util::Status Execute(std::shared_ptr request) final; + + util::Status Execute( + const std::vector>& requests) final; + + util::Status Cancel(std::shared_ptr request) final; + + util::Status CancelAllRequests() final; + + uint64_t allocation_alignment_bytes() const final; + + Buffer MakeBuffer(size_t size_bytes) const final; + + void SetFatalErrorCallback(FatalErrorCallback callback) final; + + void SetThermalWarningCallback(ThermalWarningCallback callback) final; + + util::Status SetExecutionPreference(const api::PackageReference* package, + ExecutionPreference preference) final { + return util::OkStatus(); + } + + // Extensions to the Device interface to facilitate easier testing. + + // Submits an inference request with given test vector. 
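Taken together, a typical test harness built on the Submit entry points declared here looks roughly like the sketch below. It assumes the wrapped driver is passed as a std::unique_ptr<api::Driver> (the template argument is elided in this listing), that the packages referenced by the test vectors are already registered, and that TestVector loading happens elsewhere:

```c++
// Illustrative test-harness sketch; not part of the runtime.
util::Status RunTestVectorsSketch(std::unique_ptr<api::Driver> real_driver,
                                  const std::vector<TestVector>& test_vectors) {
  DriverHelper helper(std::move(real_driver), /*max_pending_requests=*/4,
                      /*prefill_output_tensors=*/true,
                      /*guard_area_size_bytes=*/4096);
  RETURN_IF_ERROR(helper.Open(/*debug_mode=*/false));
  for (const TestVector& test_vector : test_vectors) {
    // batches <= 0 falls back to the executable's native batch size.
    RETURN_IF_ERROR(helper.Submit(test_vector, /*batches=*/0));
  }
  // Close() waits for all pending requests and logs a throughput summary.
  return helper.Close(api::Driver::ClosingMode::kGraceful);
}
```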
+ util::Status Submit(const TestVector& test_vector, int batches) + LOCKS_EXCLUDED(mutex_); + + // Submits an inference request and execute the specified callback on + // completion. |tag| is a user friendly name for tracking this request + // (typically the model name). + util::Status Submit(const std::string& tag, + const api::PackageReference* executable_ref, + const Buffer::NamedMap& input, + const Buffer::NamedMap& output, + const Buffer::NamedMap& output_with_guard_areas, + api::Request::Done request_done) LOCKS_EXCLUDED(mutex_); + + // Submits an inference request and verify output, with optional guard area + // sorrounding the output buffers. Dumps the output upon mismatch, if + // output_file_name is not empty. + util::Status Submit( + const std::string& tag, const api::PackageReference* executable_ref, + const std::string& output_file_name, const Buffer::NamedMap& input, + const Buffer::NamedMap& expected_output, const Buffer::NamedMap& output, + const Buffer::NamedMap& output_with_guard_areas) LOCKS_EXCLUDED(mutex_); + + // Submits an inference request and verify output. + util::Status Submit(const std::string& tag, + const api::PackageReference* executable_ref, + const Buffer::NamedMap& input, + const Buffer::NamedMap& expected_output, + const Buffer::NamedMap& output) LOCKS_EXCLUDED(mutex_); + + util::Status SetRealtimeMode(bool on) override; + + util::Status SetExecutableTiming(const api::PackageReference* executable, + const api::Timing& timing) override; + + void SetTelemeterInterface( + api::TelemeterInterface* telemeter_interface) override {} + + void UpdateOperationalSettings(const OperationalSettings& settings) override { + driver_->UpdateOperationalSettings(settings); + } + + private: + // Wrapped driver instance. + std::unique_ptr driver_; + + // Maximum number of pending requests. + const int max_pending_requests_{1}; + + // Current number of pending requests. + int pending_requests_ GUARDED_BY(mutex_){0}; + + // Total number of requests processed so far. + int total_requests_ GUARDED_BY(mutex_){0}; + + // Condition variable to synchronously wait for pending requests. + std::condition_variable cv_ GUARDED_BY(mutex_); + + // Guards pending_requests_, cv_ and other internal states. + mutable std::mutex mutex_; + + // Time at which first submit was called. + std::chrono::steady_clock::time_point first_submit_; + + // A vector of roundtrip times for all requests in milliseconds. Roundtrip + // time is measured from when driver::submit is called until the callback is + // first received. + std::vector roundtrip_times_ms_; + + // A vector of verification times for all requests in milliseconds. + // Verification time is measured from when the callback is first received + // until the callback is completed. + std::vector verification_times_ms_; + + // If true, the output tensors are pre-filled with known data pattern. + // This helps catch incomplete output activations, i.e. when any parts of the + // output memory region are not overwritten. + const bool prefill_output_tensors_; + + // If non-zero, leading and trailing guard areas would be allocated for every + // output buffer, and filled with known data pattern. These guard areas would + // then be checked when a request is completed, to detect data overflow. + // The size should be page-aligned for PCIe use cases. + // Note that in cases the driver always makes a copy of the output buffers, + // this mechanism would only catch driver-caused overruns. 
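+  //
+  // Illustrative layout (sizes and names below are assumed for the example,
+  // not taken from this header): with guard_area_size_bytes_ = 4096 and a
+  // 16 KB output tensor, each guarded allocation looks like
+  //
+  //   [ 4 KB leading guard | 16 KB output | 4 KB trailing guard ]
+  //
+  // with both guards filled with a fixed byte pattern before submission. A
+  // verification sketch, assuming a hypothetical kGuardByte pattern and
+  // pointers to the two guard regions:
+  //
+  //   std::vector<uint8_t> expected(guard_area_size_bytes_, kGuardByte);
+  //   if (memcmp(leading_guard, expected.data(), expected.size()) != 0 ||
+  //       memcmp(trailing_guard, expected.data(), expected.size()) != 0) {
+  //     LOG(ERROR) << "Guard area has been tainted.";
+  //   }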
+ const size_t guard_area_size_bytes_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_DRIVER_HELPER_H_ diff --git a/driver/executable_util.cc b/driver/executable_util.cc new file mode 100644 index 0000000..a60c968 --- /dev/null +++ b/driver/executable_util.cc @@ -0,0 +1,219 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utility functions for working with executable.fbs. + +#include "driver/executable_util.h" + +#include + +#include +#include +#include +#include + +#include "api/buffer.h" +#include "executable/executable_generated.h" +#include "port/array_slice.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/macros.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; + +// Align n to the nearest multiple of n, where n is a power of 2. +int AlignNext(int value, int n) { + CHECK_EQ(0, n & (n - 1)); // must be power of 2. + return (value + n) & ~(n - 1); +} + +// Copy the low |num_bits| from |src| into |dst| at offset |dst_offset_bits|. +// |dst_offset_bits + num_bits| must be less than or equal to 8. +// Returns the original |src| but with the low |num_bits| bits shifted out. +uint32 CopyUint8LowBits(uint32 src, int dst_offset_bit, int num_bits, + uint8* dst) { + CHECK_LE(dst_offset_bit + num_bits, CHAR_BIT); + + // Mask the low |num_bits| bits from |src| and assign it to |dst| at offset + // |dst_offset_bit|. + const uint8 src_mask = (1 << num_bits) - 1; + const uint8 dst_mask = src_mask << dst_offset_bit; + + *dst = (*dst & ~dst_mask) | (src & src_mask) << dst_offset_bit; + + return src >> num_bits; // shift out bits already set. 
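+
+  // Worked example (values chosen for illustration): CopyUint8LowBits(0x1F3,
+  // /*dst_offset_bit=*/4, /*num_bits=*/4, &byte) masks off the low nibble
+  // 0x3, writes it into the upper nibble of |byte|, and returns
+  // 0x1F3 >> 4 == 0x1F so the caller can continue with the remaining bits in
+  // the next destination byte.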
+} + +void LinkBatchedAddress(Description target, const std::string& name, + const std::vector& addresses, + const Vector>* field_offsets, + gtl::MutableArraySlice encoded_buffer) { + if (field_offsets == nullptr) { + return; + } + + for (const auto& field_offset : *field_offsets) { + const auto& meta = field_offset->meta(); + if (meta->desc() != target) { + continue; + } + + if (meta->name()->str() != name) { + continue; + } + + const int batch = meta->batch(); + CHECK(batch < addresses.size()); + const uint64 link_address = addresses[batch]; + + uint32 immediate_value; + if (meta->position() == Position_LOWER_32BIT) { + VLOG(3) << StringPrintf( + "Linking %s[%d]: 0x%016llx", name.c_str(), batch, + static_cast( // NOLINT(runtime/int) + link_address)); + immediate_value = link_address & kuint32max; + } else { + CHECK_EQ(meta->position(), Position_UPPER_32BIT); + immediate_value = (link_address >> 32) & kuint32max; + } + ExecutableUtil::CopyUint32(encoded_buffer, field_offset->offset_bit(), + immediate_value); + } +} + +} // namespace + +void ExecutableUtil::CopyUint32(gtl::MutableArraySlice buffer, + int offset_bit, uint32 original_value) { + // Track current destination bit offset. + int next_dst_offset_bit = offset_bit; + + // Tracks remaining bits that needs to be set. + int remaining_bits = sizeof(original_value) * CHAR_BIT; + + // Value that needs to be set, bits that are set are shifted out. + int next_value = original_value; + + while (remaining_bits > 0) { + // Sets enough bits to align to the next 8 bit boundary. + int num_bits_to_set = + std::min(AlignNext(next_dst_offset_bit, CHAR_BIT) - next_dst_offset_bit, + remaining_bits); + + // Offset byte and bit offset with in the byte. + int dst_byte = next_dst_offset_bit / CHAR_BIT; + int dst_bit = next_dst_offset_bit % CHAR_BIT; + + // Copy lower |num_bits_to_set| from next_value into the destination byte + // at the specified offset. + next_value = CopyUint8LowBits(next_value, dst_bit, num_bits_to_set, + &buffer[dst_byte]); + + remaining_bits -= num_bits_to_set; + next_dst_offset_bit += num_bits_to_set; + } +} + +void ExecutableUtil::LinkScratchAddress( + uint64 scratch_address, const Vector>* field_offsets, + gtl::MutableArraySlice encoded_buffer) { + if (field_offsets == nullptr) { + return; + } + + for (const auto& field_offset : *field_offsets) { + const auto& meta = field_offset->meta(); + if (meta->desc() != Description_BASE_ADDRESS_SCRATCH) { + continue; + } + + // TODO: Add support for batch. 
+ CHECK_EQ(meta->batch(), 0); + + uint32 immediate_value; + if (meta->position() == Position_LOWER_32BIT) { + VLOG(3) << StringPrintf( + "Linking Scratch: 0x%016llx", + static_cast( // NOLINT(runtime/int) + scratch_address)); + immediate_value = scratch_address & kuint32max; + } else { + CHECK_EQ(meta->position(), Position_UPPER_32BIT); + immediate_value = (scratch_address >> 32) & kuint32max; + } + + CopyUint32(encoded_buffer, field_offset->offset_bit(), immediate_value); + } +} + +void ExecutableUtil::LinkParameterAddress( + uint64 parameter_address, const Vector>* field_offsets, + gtl::MutableArraySlice encoded_buffer) { + if (field_offsets == nullptr) { + return; + } + + for (const auto& field_offset : *field_offsets) { + const auto& meta = field_offset->meta(); + if (meta->desc() != Description_BASE_ADDRESS_PARAMETER) { + continue; + } + + uint32 immediate_value; + if (meta->position() == Position_LOWER_32BIT) { + VLOG(3) << StringPrintf( + "Linking Parameter: 0x%016llx", + static_cast( // NOLINT(runtime/int) + parameter_address)); + immediate_value = parameter_address & kuint32max; + } else { + CHECK_EQ(meta->position(), Position_UPPER_32BIT); + immediate_value = (parameter_address >> 32) & kuint32max; + } + + CopyUint32(encoded_buffer, field_offset->offset_bit(), immediate_value); + } +} + +void ExecutableUtil::LinkInputAddress( + const std::string& input_name, const std::vector& input_addresses, + const Vector>* field_offsets, + gtl::MutableArraySlice encoded_buffer) { + LinkBatchedAddress(Description_BASE_ADDRESS_INPUT_ACTIVATION, input_name, + input_addresses, field_offsets, encoded_buffer); +} + +void ExecutableUtil::LinkOutputAddress( + const std::string& output_name, const std::vector& output_addresses, + const Vector>* field_offsets, + gtl::MutableArraySlice encoded_buffer) { + LinkBatchedAddress(Description_BASE_ADDRESS_OUTPUT_ACTIVATION, output_name, + output_addresses, field_offsets, encoded_buffer); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/executable_util.h b/driver/executable_util.h new file mode 100644 index 0000000..ae2baec --- /dev/null +++ b/driver/executable_util.h @@ -0,0 +1,83 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_EXECUTABLE_UTIL_H_ +#define DARWINN_DRIVER_EXECUTABLE_UTIL_H_ + +#include +#include +#include +#include + +#include "api/buffer.h" +#include "executable/executable_generated.h" +#include "port/array_slice.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Utility functions for working with executable.fbs. +class ExecutableUtil { + public: + // Processes the input instruction stream and generates an output instruction + // stream with the meta fields populated with the given scratch address. 
Due + // to the way flatbuffers are packed, field_offsets can be nullptr which is + // treated the same as empty vector in this function. + static void LinkScratchAddress( + uint64 scratch_address, + const flatbuffers::Vector>* + field_offsets, + gtl::MutableArraySlice encoded_buffer); + + // Processes the input instruction stream and generates an output instruction + // stream with the meta fields populated with the given host addresses. Due + // to the way flatbuffers are packed, field_offsets can be nullptr which is + // treated the same as empty vector in this function. + static void LinkParameterAddress( + uint64 parameter_address, + const flatbuffers::Vector>* + field_offsets, + gtl::MutableArraySlice encoded_buffer); + + static void LinkInputAddress( + const std::string& input_name, const std::vector& input_addresses, + const flatbuffers::Vector>* + field_offsets, + gtl::MutableArraySlice encoded_buffer); + + static void LinkOutputAddress( + const std::string& output_name, + const std::vector& output_addresses, + const flatbuffers::Vector>* + field_offsets, + gtl::MutableArraySlice encoded_buffer); + + // Convenience function to set a uint32 value on the specified bitoffset. + static void CopyUint32(gtl::MutableArraySlice buffer, int offset_bit, + uint32 value); + + private: + // Static class. + ExecutableUtil(); +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_EXECUTABLE_UTIL_H_ diff --git a/driver/hardware_structures.h b/driver/hardware_structures.h new file mode 100644 index 0000000..d17c678 --- /dev/null +++ b/driver/hardware_structures.h @@ -0,0 +1,171 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Various Hardware Structures and Constants. + +#ifndef DARWINN_DRIVER_HARDWARE_STRUCTURES_H_ +#define DARWINN_DRIVER_HARDWARE_STRUCTURES_H_ + +#include +#include + +#include "port/integral_types.h" +#include "port/macros.h" + +#if defined(ATTRIBUTE_PACKED) & !defined(ABSL_ATTRIBUTE_PACKED) +#define ABSL_ATTRIBUTE_PACKED ATTRIBUTE_PACKED +#endif + +namespace platforms { +namespace darwinn { +namespace driver { + +// DarwiNN page table management constants. + +// DarwiNN virtual address format +// Simple addressing: +// [63] | [62:25] | [24:12] | [11:0] +// 0 | Reserved [0...] | Page Table Index | Page Offset +// +// Extended addressing: +// [63] | [62:34] | [33:21] | [20:12] | [11:0] +// 1 | Reserved [0] | Extended PT Index | Host Table Index | Page Offset + +// The MSB. +static constexpr uint64 kExtendedVirtualAddressBit = (1ULL << 63); +// Simple addressing: page table index. +static constexpr uint64 kSimplePageTableIndexShiftBits = 12; +static constexpr uint64 kSimplePageTableIndexWidthBits = 13; +// Extended addressing: page table index. +static constexpr uint64 kExtendedPageTableIndexShiftBits = 21; +static constexpr uint64 kExtendedPageTableIndexWidthBits = 13; +// Extended addressing: host page table index. 
+static constexpr uint64 kExtendedHostPageTableIndexShiftBits = 12; +static constexpr uint64 kExtendedHostPageTableIndexWidthBits = 9; +static constexpr uint64 kExtendedHostPageTableSizePerPage = + (1ULL << kExtendedHostPageTableIndexWidthBits); + +// Host page info. 4096 bytes. +static constexpr uint64 kHostPageShiftBits = 12; +static constexpr uint64 kHostPageSize = (1ULL << kHostPageShiftBits); + +// Manage valid / Invalid page table entries. +static constexpr uint64 kValidPageTableEntryMask = 1; +static constexpr uint64 kInvalidPageTableEntryValue = 0; + +// Manage bar number and offsets. +static constexpr uint64 kDarwinnBarNumber = 2; +static constexpr uint64 kDarwinnBarSize = 1ULL * 1024ULL * 1024ULL; + +// Defines a descriptor to fetch instructions in the host queue. +struct alignas(16) HostQueueDescriptor { + uint64 address; + uint32 size_in_bytes; + uint32 reserved; +} ABSL_ATTRIBUTE_PACKED; +static_assert(sizeof(HostQueueDescriptor) == 16, "Must be 16 bytes."); + +// Defines the status block that hardware updates. +struct alignas(16) HostQueueStatusBlock { + // The value of completed_head pointer when the status block was updated. + uint32 completed_head_pointer; + // A bit to indicate that fatal error has occured for the host queue. Using + // uint32 to align it to 8B boundary. + uint32 fatal_error; + uint64_t reserved; +} ABSL_ATTRIBUTE_PACKED; +static_assert(sizeof(HostQueueStatusBlock) == 16, "Must be 16 bytes."); + +// An MSIX table entry as shown in Figure 6-11 in PCI local bus specification +// rev 3.0 document. +struct MsixTableEntry { + // An address to perform PCIe write at for an interrupt. + uint64 message_address; + // Data to send in PCIe write for an interrupt. + uint32 message_data; + // LSB is used to mask an interrupt. Other bits are reserved. + uint32 vector_control; +} ABSL_ATTRIBUTE_PACKED; +static_assert(sizeof(MsixTableEntry) == 16, "Must be 16 bytes."); + +// Size in bytes addressable by a single extended page table entry. +// When kHostPageSize is 4K, this is 2MB. +static constexpr uint64 kExtendedPageTableEntryAddressableBytes = + kExtendedHostPageTableSizePerPage * kHostPageSize; + +// Size in bytes of the configured DarwiNN extended address space range. +// Must be a multiple of |kExtendedPageTableEntryAddressableBytes|. The maximum +// addressable extended address space range is 16 GB. However, this is +// restricted to 4GB to avoid using 64 bit math in the scalar core. +// See: go/g.d/1a9uNlUCrEu43L31v_gRENgCjKW4MMA-B-L64cR8z3I4 +static constexpr uint64 kExtendedAddressSpaceStart = 0x8000000000000000L; +static constexpr uint64 kExtendedAddressSpaceSizeBytes = + (4096 * 1024 * 1024ULL); +static constexpr int kExtendedAddressSpacePrefixWidthBits = 32; +static_assert( + (kExtendedAddressSpaceStart >> kExtendedAddressSpacePrefixWidthBits) == + ((kExtendedAddressSpaceStart + kExtendedAddressSpaceSizeBytes - 1) >> + kExtendedAddressSpacePrefixWidthBits), + "Extended address space range cannot span 4 GB boundaries."); +static_assert((kExtendedAddressSpaceSizeBytes % + kExtendedPageTableEntryAddressableBytes == 0), + "Must be multiple of extended host page"); + +// The upper 32 bits of the extended address space segment. +static constexpr uint32 kExtendedAddressSpacePrefix = + kExtendedAddressSpaceStart >> kExtendedAddressSpacePrefixWidthBits; + +// Simple / Extended page table entry split. +// At the minimum, simple address space needs 256 * 4kB = 1MB. 
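+//
+// Worked example (the 8192-entry figure is hypothetical, not a value from
+// this header): with 4 KB host pages each extended entry spans
+// 512 * 4 KB = 2 MB, so the 4 GB extended range needs at most
+// 4 GB / 2 MB = 2048 entries. A device exposing 8192 page table entries
+// would therefore get GetNumSimplePageTableEntries(8192) = 8192 - 2048 = 6144
+// simple entries (about 24 MB of simply-mapped space) and
+// GetNumExtendedPageTableEntries(8192) = 2048.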
+static constexpr int kMinNumSimplePageTableEntries = 256; + +// At the maximum, 2048 * 2MB = 4GB is reserved for extended address space. +static constexpr int kMaxNumExtendedPageTableEntries = + kExtendedAddressSpaceSizeBytes / kExtendedPageTableEntryAddressableBytes; + +// Returns number of simple page table entries given page table size. +inline int GetNumSimplePageTableEntries(int num_page_table_entries) { + const int num_simple_entries = + num_page_table_entries - kMaxNumExtendedPageTableEntries; + return std::max(num_simple_entries, kMinNumSimplePageTableEntries); +} + +// Returns number of extended page table entries given page table size. +inline int GetNumExtendedPageTableEntries(int num_page_table_entries) { + return num_page_table_entries - + GetNumSimplePageTableEntries(num_page_table_entries); +} + +// Run control settings for tiles and scalar core. +enum class RunControl { + kMoveToIdle = 0, + kMoveToRun = 1, + kMoveToHalt = 2, + kMoveToSingleStep = 3, +}; + +// Run status settings for tiles and scalar core. +enum class RunStatus { + kIdle = 0, + kRun = 1, + kSingleStep = 2, + kHalting = 3, + kHalted = 4, +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_HARDWARE_STRUCTURES_H_ diff --git a/driver/instruction_buffers.cc b/driver/instruction_buffers.cc new file mode 100644 index 0000000..8666873 --- /dev/null +++ b/driver/instruction_buffers.cc @@ -0,0 +1,122 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/instruction_buffers.h" + +#include + +#include "driver/aligned_allocator.h" +#include "driver/device_buffer_mapper.h" +#include "driver/executable_util.h" +#include "executable/executable_generated.h" +#include "port/logging.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; +using ::flatbuffers::VectorLength; + +InstructionBuffers::InstructionBuffers( + Allocator* const allocator, + const Vector>& instruction_bitstreams) { + // Allocate and create an aligned copy of instruction bitstream. + buffers_.reserve(VectorLength(&instruction_bitstreams)); + + for (const auto& chunk : instruction_bitstreams) { + auto buffer = allocator->MakeBuffer(chunk->bitstream()->Length()); + buffers_.push_back(std::move(buffer)); + memcpy(buffers_.back().ptr(), chunk->bitstream()->data(), + chunk->bitstream()->Length()); + } + VLOG(10) << "InstructionBuffers created."; +} + +InstructionBuffers::~InstructionBuffers() { + buffers_.clear(); + VLOG(10) << "InstructionBuffers destroyed."; +} + +void InstructionBuffers::LinkInstructionBuffers( + const DeviceBuffer& parameter_device_buffer, + DeviceBufferMapper* device_buffer_mapper, + const Vector>& instruction_bitstreams) { + TRACE_SCOPE("InstructionBuffers::LinkInstructionBuffers"); + + // Update the instruction stream to link the input, output and parameter + // addresses. 
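+  //
+  // Each FieldOffset in the executable records the bit offset of a 32-bit
+  // immediate in the bitstream and whether it holds the lower or upper half
+  // of a 64-bit device address. As an illustration (address made up for the
+  // example), a parameter buffer mapped at 0x0000000200001000 is patched as
+  // 0x00001000 into every Position_LOWER_32BIT slot and 0x00000002 into
+  // every Position_UPPER_32BIT slot via ExecutableUtil::CopyUint32.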
+ for (int i = 0; i < VectorLength(&instruction_bitstreams); ++i) { + // Link scratch address if necessary. + // Note: we may be able to optimize this scratch address linking like the + // parameters below, so that we don't re-link it every time since we have + // more control on the scratch memory and could keep it at the same address. + // It's unclear how easy to make this change at this point though, and we + // could revisit this later if needed (yuchicheng). + if (device_buffer_mapper->GetScratchDeviceBuffer().IsValid()) { + ExecutableUtil::LinkScratchAddress( + device_buffer_mapper->GetScratchDeviceBuffer().device_address(), + instruction_bitstreams.Get(i)->field_offsets(), + gtl::MutableArraySlice( + buffers_[i].ptr(), + VectorLength(instruction_bitstreams.Get(i)->bitstream()))); + } + + // Link parameters if necessary. + if (parameter_device_buffer.IsValid()) { + const uint64 linked_parameter_address = + parameter_device_buffer.device_address(); + ExecutableUtil::LinkParameterAddress( + linked_parameter_address, + instruction_bitstreams.Get(i)->field_offsets(), + gtl::MutableArraySlice( + buffers_[i].ptr(), + VectorLength(instruction_bitstreams.Get(i)->bitstream()))); + } + + for (const auto& name_and_mapped_input : + device_buffer_mapper->GetInputDeviceBuffers()) { + std::vector linked_input_addresses; + for (const auto& mapped_input : name_and_mapped_input.second) { + linked_input_addresses.push_back(mapped_input.device_address()); + } + ExecutableUtil::LinkInputAddress( + name_and_mapped_input.first, linked_input_addresses, + instruction_bitstreams.Get(i)->field_offsets(), + gtl::MutableArraySlice( + buffers_[i].ptr(), + VectorLength(instruction_bitstreams.Get(i)->bitstream()))); + } + + for (const auto& name_and_mapped_output : + device_buffer_mapper->GetOutputDeviceBuffers()) { + std::vector linked_output_addresses; + for (const auto& mapped_output : name_and_mapped_output.second) { + linked_output_addresses.push_back(mapped_output.device_address()); + } + ExecutableUtil::LinkOutputAddress( + name_and_mapped_output.first, linked_output_addresses, + instruction_bitstreams.Get(i)->field_offsets(), + gtl::MutableArraySlice( + buffers_[i].ptr(), + VectorLength(instruction_bitstreams.Get(i)->bitstream()))); + } + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/instruction_buffers.h b/driver/instruction_buffers.h new file mode 100644 index 0000000..494bc54 --- /dev/null +++ b/driver/instruction_buffers.h @@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INSTRUCTION_BUFFERS_H_ +#define DARWINN_DRIVER_INSTRUCTION_BUFFERS_H_ + +#include +#include + +#include "api/buffer.h" +#include "driver/allocator.h" +#include "driver/device_buffer_mapper.h" +#include "executable/executable_generated.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Wrapper class for handling instruction buffers. 
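+//
+// Rough usage sketch (the surrounding objects are assumed for illustration,
+// they are not declared in this header):
+//
+//   InstructionBuffers instructions(allocator, *bitstreams);
+//   // ... map inputs/outputs/parameters with a DeviceBufferMapper ...
+//   instructions.LinkInstructionBuffers(parameter_device_buffer, &mapper,
+//                                       *bitstreams);
+//   for (const Buffer& buffer : instructions.GetBuffers()) {
+//     // hand each linked instruction chunk to the execution queue
+//   }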
+class InstructionBuffers { + public: + // Constructs the instruction buffers by allocating and copying instruction + // stream to host memory. + InstructionBuffers( + platforms::darwinn::driver::Allocator *allocator, + const flatbuffers::Vector> + &instruction_bitstreams); + ~InstructionBuffers(); + + // Links scratch address, parameters, input, and output. + void LinkInstructionBuffers( + const DeviceBuffer ¶meter_device_buffer, + DeviceBufferMapper *device_buffer_mapper, + const flatbuffers::Vector> + &instruction_bitstreams); + + // Returns the reference to the buffer vector. + const std::vector &GetBuffers() const { return buffers_; } + + private: + // The actual buffers which holds the instruction stream. + std::vector buffers_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INSTRUCTION_BUFFERS_H_ diff --git a/driver/interrupt/BUILD b/driver/interrupt/BUILD new file mode 100644 index 0000000..66e95d4 --- /dev/null +++ b/driver/interrupt/BUILD @@ -0,0 +1,89 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Interrupt functionality. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "interrupt_handler", + hdrs = ["interrupt_handler.h"], + deps = ["//port"], +) + +cc_library( + name = "wire_interrupt_handler", + srcs = ["wire_interrupt_handler.cc"], + hdrs = ["wire_interrupt_handler.h"], + deps = [ + ":interrupt_handler", + "//driver/config", + "//driver/registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "interrupt_controller_interface", + hdrs = ["interrupt_controller_interface.h"], + deps = ["//port"], +) + +cc_library( + name = "dummy_interrupt_controller", + hdrs = ["dummy_interrupt_controller.h"], + deps = [ + ":interrupt_controller_interface", + "//port", + ], +) + +cc_library( + name = "interrupt_controller", + srcs = ["interrupt_controller.cc"], + hdrs = ["interrupt_controller.h"], + deps = [ + ":interrupt_controller_interface", + "//driver/config", + "//driver/registers", + "//port", + ], +) + +cc_library( + name = "grouped_interrupt_controller", + srcs = ["grouped_interrupt_controller.cc"], + hdrs = ["grouped_interrupt_controller.h"], + deps = [ + ":interrupt_controller", + ":interrupt_controller_interface", + "//port", + ], +) + +cc_library( + name = "top_level_interrupt_manager", + srcs = ["top_level_interrupt_manager.cc"], + hdrs = ["top_level_interrupt_manager.h"], + deps = [ + ":interrupt_controller_interface", + "//port", + ], +) diff --git a/driver/interrupt/dummy_interrupt_controller.h b/driver/interrupt/dummy_interrupt_controller.h new file mode 100644 index 0000000..823cf19 --- /dev/null +++ b/driver/interrupt/dummy_interrupt_controller.h @@ -0,0 +1,54 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_DUMMY_INTERRUPT_CONTROLLER_H_ +#define DARWINN_DRIVER_INTERRUPT_DUMMY_INTERRUPT_CONTROLLER_H_ + +#include "driver/interrupt/interrupt_controller_interface.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Dummy class that does nothing upon interrupt related requests. +class DummyInterruptController : public InterruptControllerInterface { + public: + explicit DummyInterruptController(int num_interrupts) + : InterruptControllerInterface(num_interrupts) {} + + // This class is neither copyable nor movable. + DummyInterruptController(const DummyInterruptController&) = delete; + DummyInterruptController& operator=(const DummyInterruptController&) = delete; + + ~DummyInterruptController() = default; + + util::Status EnableInterrupts() override { + return util::Status(); // OK + } + + util::Status DisableInterrupts() override { + return util::Status(); // OK + } + + util::Status ClearInterruptStatus(int id) override { + return util::Status(); // OK + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_DUMMY_INTERRUPT_CONTROLLER_H_ diff --git a/driver/interrupt/grouped_interrupt_controller.cc b/driver/interrupt/grouped_interrupt_controller.cc new file mode 100644 index 0000000..6e3c402 --- /dev/null +++ b/driver/interrupt/grouped_interrupt_controller.cc @@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
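+
+// Fans interrupt control out to a set of per-source controllers:
+// EnableInterrupts() and DisableInterrupts() are broadcast to every child
+// controller, while ClearInterruptStatus(id) is routed to the id-th child.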
+ +#include "driver/interrupt/grouped_interrupt_controller.h" + +#include "driver/interrupt/interrupt_controller_interface.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +GroupedInterruptController::GroupedInterruptController( + std::vector>* + interrupt_controllers) + : InterruptControllerInterface(interrupt_controllers->size()), + interrupt_controllers_([interrupt_controllers]() { + CHECK(interrupt_controllers != nullptr); + return std::move(*interrupt_controllers); + }()) {} + +util::Status GroupedInterruptController::EnableInterrupts() { + for (auto& interrupt_controller : interrupt_controllers_) { + RETURN_IF_ERROR(interrupt_controller->EnableInterrupts()); + } + return util::Status(); // OK +} + +util::Status GroupedInterruptController::DisableInterrupts() { + for (auto& interrupt_controller : interrupt_controllers_) { + RETURN_IF_ERROR(interrupt_controller->DisableInterrupts()); + } + return util::Status(); // OK +} + +util::Status GroupedInterruptController::ClearInterruptStatus(int id) { + if (id < interrupt_controllers_.size()) { + return interrupt_controllers_[id]->ClearInterruptStatus(); + } + return util::FailedPreconditionError( + StringPrintf("Unknown interrupt id: %d", id)); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/interrupt/grouped_interrupt_controller.h b/driver/interrupt/grouped_interrupt_controller.h new file mode 100644 index 0000000..688a826 --- /dev/null +++ b/driver/interrupt/grouped_interrupt_controller.h @@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_GROUPED_INTERRUPT_CONTROLLER_H_ +#define DARWINN_DRIVER_INTERRUPT_GROUPED_INTERRUPT_CONTROLLER_H_ + +#include +#include + +#include "driver/interrupt/interrupt_controller.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Helper class for enabling/disabling interrupts, and clearing interrupt +// status. +class GroupedInterruptController : public InterruptControllerInterface { + public: + // "interrupt_controllers" will be empty after construction. + explicit GroupedInterruptController( + std::vector>* + interrupt_controllers); + + // This class is neither copyable nor movable. + GroupedInterruptController(const GroupedInterruptController&) = delete; + GroupedInterruptController& operator=(const GroupedInterruptController&) = + delete; + + ~GroupedInterruptController() = default; + + // Enable/disables interrupts. + util::Status EnableInterrupts() override; + util::Status DisableInterrupts() override; + + // Clears interrupt status register to notify that host has received the + // interrupt. + util::Status ClearInterruptStatus(int id) override; + + private: + // CSR offsets. 
+ const std::vector> + interrupt_controllers_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_GROUPED_INTERRUPT_CONTROLLER_H_ diff --git a/driver/interrupt/interrupt_controller.cc b/driver/interrupt/interrupt_controller.cc new file mode 100644 index 0000000..bffbc05 --- /dev/null +++ b/driver/interrupt/interrupt_controller.cc @@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/interrupt/interrupt_controller.h" + +#include "driver/registers/registers.h" +#include "port/logging.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +InterruptController::InterruptController( + const config::InterruptCsrOffsets& csr_offsets, Registers* registers, + int num_interrupts) + : InterruptControllerInterface(num_interrupts), + csr_offsets_(csr_offsets), + registers_(registers) { + CHECK(registers != nullptr); +} + +util::Status InterruptController::EnableInterrupts() { + const uint64 enable_all = (1ULL << NumInterrupts()) - 1; + return registers_->Write(csr_offsets_.control, enable_all); +} + +util::Status InterruptController::DisableInterrupts() { + constexpr uint64 kDisableAll = 0; + return registers_->Write(csr_offsets_.control, kDisableAll); +} + +util::Status InterruptController::ClearInterruptStatus(int id) { + // Interrupt status register has W0C policy meaning that writing 0 clears the + // bit, while writing 1 does not have any effect. + const uint64 clear_bit = ~(1ULL << id); + + uint64 value = (1ULL << NumInterrupts()) - 1; + value &= clear_bit; + return registers_->Write(csr_offsets_.status, value); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/interrupt/interrupt_controller.h b/driver/interrupt/interrupt_controller.h new file mode 100644 index 0000000..8290421 --- /dev/null +++ b/driver/interrupt/interrupt_controller.h @@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_H_ +#define DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_H_ + +#include "driver/config/interrupt_csr_offsets.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/registers/registers.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Helper class for enabling/disabling interrupts, and clearing interrupt +// status. 
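+//
+// The status register is write-0-to-clear (W0C), so acknowledging a single
+// source means writing all ones except the bit being cleared. For example,
+// with num_interrupts = 4, ClearInterruptStatus(2) writes 0b1011: bit 2 is
+// cleared and the remaining sources are left untouched.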
+class InterruptController : public InterruptControllerInterface { + public: + InterruptController(const config::InterruptCsrOffsets& csr_offsets, + Registers* registers, int num_interrupts = 1); + + // This class is neither copyable nor movable. + InterruptController(const InterruptController&) = delete; + InterruptController& operator=(const InterruptController&) = delete; + + ~InterruptController() = default; + + // Enable/disables interrupts. + util::Status EnableInterrupts() override; + util::Status DisableInterrupts() override; + + // Clears interrupt status register to notify that host has received the + // interrupt. + util::Status ClearInterruptStatus(int id) override; + + private: + // CSR offsets. + const config::InterruptCsrOffsets& csr_offsets_; + + // CSR interface. + Registers* const registers_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_H_ diff --git a/driver/interrupt/interrupt_controller_interface.h b/driver/interrupt/interrupt_controller_interface.h new file mode 100644 index 0000000..07d0ecf --- /dev/null +++ b/driver/interrupt/interrupt_controller_interface.h @@ -0,0 +1,58 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_INTERFACE_H_ +#define DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_INTERFACE_H_ + +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interface for enabling/disabling interrupts, and clearing interrupt status. +class InterruptControllerInterface { + public: + explicit InterruptControllerInterface(int num_interrupts) + : num_interrupts_(num_interrupts) {} + + // This class is neither copyable nor movable. + InterruptControllerInterface(const InterruptControllerInterface&) = delete; + InterruptControllerInterface& operator=(const InterruptControllerInterface&) = + delete; + + virtual ~InterruptControllerInterface() = default; + + // Enable/disables interrupts. + virtual util::Status EnableInterrupts() = 0; + virtual util::Status DisableInterrupts() = 0; + + // Clears interrupt status register to notify that host has received the + // interrupt. + virtual util::Status ClearInterruptStatus(int id) = 0; + util::Status ClearInterruptStatus() { return ClearInterruptStatus(/*id=*/0); } + + // Returns number of interrupts controlled by this interface. + int NumInterrupts() const { return num_interrupts_; } + + private: + // Number of interrupts enabled/disabled/cleared by this interface. 
+ const int num_interrupts_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_INTERRUPT_CONTROLLER_INTERFACE_H_ diff --git a/driver/interrupt/interrupt_handler.h b/driver/interrupt/interrupt_handler.h new file mode 100644 index 0000000..dfd206e --- /dev/null +++ b/driver/interrupt/interrupt_handler.h @@ -0,0 +1,68 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_INTERRUPT_HANDLER_H_ +#define DARWINN_DRIVER_INTERRUPT_INTERRUPT_HANDLER_H_ + +#include + +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interrupt identifiers. +enum Interrupt { + DW_INTERRUPT_INSTR_QUEUE = 0, + DW_INTERRUPT_INPUT_ACTV_QUEUE = 1, + DW_INTERRUPT_PARAM_QUEUE = 2, + DW_INTERRUPT_OUTPUT_ACTV_QUEUE = 3, + DW_INTERRUPT_SC_HOST_0 = 4, + DW_INTERRUPT_SC_HOST_1 = 5, + DW_INTERRUPT_SC_HOST_2 = 6, + DW_INTERRUPT_SC_HOST_3 = 7, + DW_INTERRUPT_TOP_LEVEL_0 = 8, + DW_INTERRUPT_TOP_LEVEL_1 = 9, + DW_INTERRUPT_TOP_LEVEL_2 = 10, + DW_INTERRUPT_TOP_LEVEL_3 = 11, + DW_INTERRUPT_FATAL_ERR = 12, + DW_INTERRUPT_COUNT = 13, + + // Aliases. + DW_INTERRUPT_SC_HOST_BASE = DW_INTERRUPT_SC_HOST_0, + DW_INTERRUPT_TOP_LEVEL_BASE = DW_INTERRUPT_TOP_LEVEL_0, +}; + +// Interface for handling interrupts. +class InterruptHandler { + public: + using Handler = std::function; + + virtual ~InterruptHandler() = default; + + // Open / Close the interrupt handler. + virtual util::Status Open() = 0; + virtual util::Status Close(bool in_error) = 0; + util::Status Close() { return Close(/*in_error=*/false); } + + // Registers interrupt. + virtual util::Status Register(Interrupt interrupt, Handler handler) = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_INTERRUPT_HANDLER_H_ diff --git a/driver/interrupt/top_level_interrupt_manager.cc b/driver/interrupt/top_level_interrupt_manager.cc new file mode 100644 index 0000000..333ebe7 --- /dev/null +++ b/driver/interrupt/top_level_interrupt_manager.cc @@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
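+
+// The public EnableInterrupts()/DisableInterrupts()/HandleInterrupt() entry
+// points below drive the wrapped controller and delegate the system-specific
+// work to the Do*() hooks declared in top_level_interrupt_manager.h.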
+ +#include "driver/interrupt/top_level_interrupt_manager.h" + +#include "port/status.h" +#include "port/status_macros.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status TopLevelInterruptManager::EnableInterrupts() { + RETURN_IF_ERROR(interrupt_controller_->EnableInterrupts()); + return DoEnableInterrupts(); +} + +util::Status TopLevelInterruptManager::DisableInterrupts() { + RETURN_IF_ERROR(interrupt_controller_->DisableInterrupts()); + return DoDisableInterrupts(); +} + +util::Status TopLevelInterruptManager::HandleInterrupt(int id) { + RETURN_IF_ERROR(DoHandleInterrupt(id)); + return interrupt_controller_->ClearInterruptStatus(id); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/interrupt/top_level_interrupt_manager.h b/driver/interrupt/top_level_interrupt_manager.h new file mode 100644 index 0000000..31210b4 --- /dev/null +++ b/driver/interrupt/top_level_interrupt_manager.h @@ -0,0 +1,76 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_TOP_LEVEL_INTERRUPT_MANAGER_H_ +#define DARWINN_DRIVER_INTERRUPT_TOP_LEVEL_INTERRUPT_MANAGER_H_ + +#include + +#include "driver/interrupt/interrupt_controller_interface.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Base class for top level interrupt management. +class TopLevelInterruptManager { + public: + explicit TopLevelInterruptManager( + std::unique_ptr interrupt_controller) + : interrupt_controller_(std::move(interrupt_controller)) {} + virtual ~TopLevelInterruptManager() = default; + + // Opens/closes the controller. + virtual util::Status Open() { + return util::Status(); // OK + } + virtual util::Status Close() { + return util::Status(); // OK + } + + // Enable/disables interrupts. + util::Status EnableInterrupts(); + util::Status DisableInterrupts(); + + // Handles interrupt. + util::Status HandleInterrupt(int id); + + // Returns number of top level interrupts. + int NumInterrupts() const { return interrupt_controller_->NumInterrupts(); } + + protected: + // Actually enables/disables interrupts, which are system-specific. + virtual util::Status DoEnableInterrupts() { + return util::Status(); // OK + } + virtual util::Status DoDisableInterrupts() { + return util::Status(); // OK + } + + // Actually handles interrupts, which are system-specific. + virtual util::Status DoHandleInterrupt(int id) { + return util::Status(); // OK + } + + private: + // Interrupt controller. 
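+  // EnableInterrupts()/DisableInterrupts() forward to it before running the
+  // Do*() hooks, and HandleInterrupt(id) clears the corresponding status bit
+  // through it after DoHandleInterrupt(id) succeeds. A system-specific
+  // manager only needs to override the hooks, e.g. (illustrative sketch):
+  //
+  //   class FooTopLevelInterruptManager : public TopLevelInterruptManager {
+  //    public:
+  //     using TopLevelInterruptManager::TopLevelInterruptManager;
+  //
+  //    protected:
+  //     util::Status DoHandleInterrupt(int id) override {
+  //       // Read and log the system-specific cause register for |id| here.
+  //       return util::Status();  // OK
+  //     }
+  //   };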
+ std::unique_ptr interrupt_controller_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_TOP_LEVEL_INTERRUPT_MANAGER_H_ diff --git a/driver/interrupt/wire_interrupt_handler.cc b/driver/interrupt/wire_interrupt_handler.cc new file mode 100644 index 0000000..f2d0d3e --- /dev/null +++ b/driver/interrupt/wire_interrupt_handler.cc @@ -0,0 +1,344 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/interrupt/wire_interrupt_handler.h" + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "driver/config/common_csr_helper.h" +#include "driver/config/wire_csr_offsets.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +constexpr const uint64 kQuiescedRegValue = 0xdeadfeeddeadfeedLL; + +WireInterruptHandler::WireInterruptHandler( + Registers* registers, const config::WireCsrOffsets& wire_csr_offsets, + int num_wires) + : registers_(registers), + wire_csr_offsets_(wire_csr_offsets), + num_wires_(num_wires) { + CHECK(registers != nullptr); + // Only supports 1 wire and 3 wire interrupt as of now. + CHECK(num_wires_ == 1 || num_wires_ == 3); + interrupts_.resize(Interrupt::DW_INTERRUPT_COUNT); +} + +util::Status WireInterruptHandler::ValidateOpenState(bool open) const { + if (open_ != open) { + return util::FailedPreconditionError( + "Invalid state in WireInterruptHandler."); + } + return util::Status(); // OK +} + +util::Status WireInterruptHandler::Open() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/false)); + open_ = true; + + for (int i = 0; i < DW_INTERRUPT_COUNT; ++i) { + interrupts_[i] = nullptr; + } + + return util::Status(); // OK +} + +util::Status WireInterruptHandler::Close(bool in_error) { + // If in error, interrupt handler is already serving fatal error, and mutex is + // already locked. To avoid deadlock, return immediately. 
+ if (in_error) { + return util::Status(); // OK + } + + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/true)); + open_ = false; + + for (int i = 0; i < DW_INTERRUPT_COUNT; ++i) { + interrupts_[i] = nullptr; + } + return util::Status(); // OK +} + +void WireInterruptHandler::MaskInterrupt(int interrupt_id, bool mask) { + config::registers::WireIntBitArray bit_array_helper_int_mask(ReadMaskArray()); + switch (interrupt_id) { + case DW_INTERRUPT_INSTR_QUEUE: + bit_array_helper_int_mask.set_instruction_queue(mask); + break; + case DW_INTERRUPT_SC_HOST_0: + bit_array_helper_int_mask.set_sc_host_0(mask); + break; + case DW_INTERRUPT_SC_HOST_1: + bit_array_helper_int_mask.set_sc_host_1(mask); + break; + case DW_INTERRUPT_SC_HOST_2: + bit_array_helper_int_mask.set_sc_host_2(mask); + break; + case DW_INTERRUPT_SC_HOST_3: + bit_array_helper_int_mask.set_sc_host_3(mask); + break; + case DW_INTERRUPT_FATAL_ERR: + bit_array_helper_int_mask.set_fatal_err(mask); + break; + default: + LOG(FATAL) << "MaskInterrupt: unhandled interrupt id: " << interrupt_id; + } + CHECK_OK(WriteMaskArray(bit_array_helper_int_mask.raw())); +} + +void WireInterruptHandler::InvokeInterruptWithMask(int interrupt_id) { + StdMutexLock lock(&mutex_); + if (interrupts_[interrupt_id]) { + // Mask and unmask interrupt due to b/111367622. + // This is Noronha-specific, but shouldn't hurt for other architectures. + MaskInterrupt(interrupt_id, true); + interrupts_[interrupt_id](); + MaskInterrupt(interrupt_id, false); + } +} + +void WireInterruptHandler::InvokeInterrupt(int interrupt_id) { + StdMutexLock lock(&mutex_); + if (interrupts_[interrupt_id]) { + interrupts_[interrupt_id](); + } +} + +uint64 WireInterruptHandler::ReadPendingBitArray() { + return registers_->Read(wire_csr_offsets_.wire_int_pending_bit_array) + .ValueOrDie(); +} + +uint64 WireInterruptHandler::ReadMaskArray() { + return registers_->Read(wire_csr_offsets_.wire_int_mask_array).ValueOrDie(); +} + +util::Status WireInterruptHandler::WriteMaskArray(uint64 value) { + return registers_->Write(wire_csr_offsets_.wire_int_mask_array, value); +} + +void WireInterruptHandler::HandlePlatformSingleWireInterrupt() { + config::registers::WireIntBitArray bit_array_helper(ReadPendingBitArray()); + config::registers::WireIntBitArray bit_array_helper_int_mask(ReadMaskArray()); + + while (bit_array_helper.raw() != 0) { + if (bit_array_helper.raw() == kQuiescedRegValue) { + // We re-entered this loop after chip was put in clock gating state, + // hence nothing to do. 
+ break; + } + + if (bit_array_helper.instruction_queue()) { + InvokeInterrupt(DW_INTERRUPT_INSTR_QUEUE); + bit_array_helper_int_mask.set_instruction_queue(0); + } + + if (bit_array_helper.sc_host_0()) { + InvokeInterrupt(DW_INTERRUPT_SC_HOST_0); + bit_array_helper_int_mask.set_sc_host_0(0); + } + + if (bit_array_helper.sc_host_1()) { + InvokeInterrupt(DW_INTERRUPT_SC_HOST_1); + bit_array_helper_int_mask.set_sc_host_1(0); + } + + if (bit_array_helper.sc_host_2()) { + InvokeInterrupt(DW_INTERRUPT_SC_HOST_2); + bit_array_helper_int_mask.set_sc_host_2(0); + } + + if (bit_array_helper.sc_host_3()) { + InvokeInterrupt(DW_INTERRUPT_SC_HOST_3); + bit_array_helper_int_mask.set_sc_host_3(0); + } + + if (bit_array_helper.fatal_err()) { + InvokeInterrupt(DW_INTERRUPT_FATAL_ERR); + bit_array_helper_int_mask.set_fatal_err(0); + } + + if (bit_array_helper.top_level_0() || bit_array_helper.top_level_1() || + bit_array_helper.top_level_2() || bit_array_helper.top_level_3()) { + LOG(WARNING) << "Unsupported top level interrupt raised."; + } + + if (bit_array_helper.param_queue() || bit_array_helper.input_actv_queue() || + bit_array_helper.output_actv_queue()) { + LOG(WARNING) << "Unsupported queue interrupt raised."; + } + + // Mask bits are set in kernel-land : unmask interrupts when user-land + // handler has completed. + bit_array_helper = + config::registers::WireIntBitArray(ReadPendingBitArray()); + + CHECK_OK(WriteMaskArray(bit_array_helper_int_mask.raw())); + } +} + +void WireInterruptHandler::HandleMsi3WireInterrupt(int wire_id) { + CHECK_LT(wire_id, num_wires_); + + switch (wire_id) { + // Scalar core interrupt 0. + case 0: + InvokeInterruptWithMask(DW_INTERRUPT_SC_HOST_0); + break; + + // Instruction queue interrupt. + case 1: + InvokeInterruptWithMask(DW_INTERRUPT_INSTR_QUEUE); + break; + + // Remaining. + default: { + config::registers::WireIntBitArray bit_array_helper( + ReadPendingBitArray()); + + while (bit_array_helper.raw() != 0) { + if (bit_array_helper.raw() == kQuiescedRegValue) { + // We re-entered this loop after chip was put in clock gating state, + // hence nothing to do. + break; + } + + if (bit_array_helper.sc_host_1()) { + InvokeInterruptWithMask(DW_INTERRUPT_SC_HOST_1); + } + + if (bit_array_helper.sc_host_2()) { + InvokeInterruptWithMask(DW_INTERRUPT_SC_HOST_2); + } + + if (bit_array_helper.sc_host_3()) { + InvokeInterruptWithMask(DW_INTERRUPT_SC_HOST_3); + } + + if (bit_array_helper.fatal_err()) { + InvokeInterruptWithMask(DW_INTERRUPT_FATAL_ERR); + } + + if (bit_array_helper.top_level_0() || bit_array_helper.top_level_1() || + bit_array_helper.top_level_2() || bit_array_helper.top_level_3()) { + LOG(WARNING) << "Unsupported top level interrupt raised."; + } + + if (bit_array_helper.param_queue() || + bit_array_helper.input_actv_queue() || + bit_array_helper.output_actv_queue()) { + LOG(WARNING) << "Unsupported queue interrupt raised."; + } + + // Mask bits are set in kernel-land : unmask interrupts when user-land + // handler has completed. 
+ bit_array_helper = + config::registers::WireIntBitArray(ReadPendingBitArray()); + } + + break; + } + } +} + +void WireInterruptHandler::InvokeAllPendingInterrupts(int wire_id) { + if (num_wires_ == 3) { + return HandleMsi3WireInterrupt(wire_id); + } else { + return HandlePlatformSingleWireInterrupt(); + } +} + +util::Status WireInterruptHandler::Register(Interrupt interrupt, + Handler handler) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/true)); + + interrupts_[interrupt] = std::move(handler); + return util::Status(); // OK; +} + +PollingWireInterruptHandler::PollingWireInterruptHandler( + Registers* registers, const config::WireCsrOffsets& wire_csr_offsets, + std::function sleep) + : WireInterruptHandler(registers, wire_csr_offsets, /*num_wires=*/1), + sleep_(std::move(sleep)) {} + +util::Status PollingWireInterruptHandler::Open() { + StdMutexLock lock(&mutex_); + if (enabled_) { + return util::FailedPreconditionError( + "Invalid state in WireInterruptHandler."); + } + RETURN_IF_ERROR(WireInterruptHandler::Open()); + enabled_ = true; + + std::thread event_thread(&PollingWireInterruptHandler::PollInterrupts, this); + thread_ = std::move(event_thread); + + return util::Status(); // OK +} + +util::Status PollingWireInterruptHandler::Close(bool in_error) { + { + StdMutexLock lock(&mutex_); + if (!enabled_) { + return util::FailedPreconditionError( + "Invalid state in WireInterruptHandler."); + } + enabled_ = false; + } + + // Wait for thread to exit. + thread_.join(); + return WireInterruptHandler::Close(in_error); +} + +bool PollingWireInterruptHandler::IsEnabled() const { + StdMutexLock lock(&mutex_); + return enabled_; +} + +void PollingWireInterruptHandler::PollInterrupts() { + VLOG(5) << StringPrintf("Interrupt monitor thread enter."); + TRACE_START_THREAD("PollingWireInterruptHandler"); + + do { + sleep_(); + InvokeAllPendingInterrupts(/*wire_id=*/0); + } while (IsEnabled()); + + VLOG(5) << StringPrintf("Interrupt monitor thread exit."); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/interrupt/wire_interrupt_handler.h b/driver/interrupt/wire_interrupt_handler.h new file mode 100644 index 0000000..a748b02 --- /dev/null +++ b/driver/interrupt/wire_interrupt_handler.h @@ -0,0 +1,145 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_INTERRUPT_WIRE_INTERRUPT_HANDLER_H_ +#define DARWINN_DRIVER_INTERRUPT_WIRE_INTERRUPT_HANDLER_H_ + +#include // NOLINT +#include // NOLINT +#include + +#include "driver/config/wire_csr_offsets.h" +#include "driver/interrupt/interrupt_handler.h" +#include "driver/registers/registers.h" +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Wire Interrupt handler implementation. +class WireInterruptHandler : public InterruptHandler { + public: + // Default close to avoid name hiding. 
+  using InterruptHandler::Close;
+
+  WireInterruptHandler(Registers* registers,
+                       const config::WireCsrOffsets& wire_csr_offsets,
+                       int num_wires);
+  ~WireInterruptHandler() override = default;
+
+  // This class is neither copyable nor movable.
+  WireInterruptHandler(const WireInterruptHandler&) = delete;
+  WireInterruptHandler& operator=(const WireInterruptHandler&) = delete;
+
+  util::Status Open() LOCKS_EXCLUDED(mutex_) override;
+  util::Status Close(bool in_error) LOCKS_EXCLUDED(mutex_) override;
+
+  util::Status Register(Interrupt interrupt, Handler handler)
+      LOCKS_EXCLUDED(mutex_) override;
+
+  // Checks the pending bit array and invokes interrupts.
+  virtual void InvokeAllPendingInterrupts(int wire_id);
+
+ private:
+  // Invokes the handler for the specified interrupt.
+  void InvokeInterrupt(int interrupt_id) LOCKS_EXCLUDED(mutex_);
+
+  // Invokes the handler for the specified interrupt, and masks the interrupt
+  // during processing.
+  void InvokeInterruptWithMask(int interrupt_id) LOCKS_EXCLUDED(mutex_);
+
+  // Masks and unmasks given interrupt sources.
+  void MaskInterrupt(int interrupt_id, bool mask)
+      SHARED_LOCKS_REQUIRED(mutex_);
+
+  // Performs CSR read access.
+  uint64 ReadPendingBitArray();
+  uint64 ReadMaskArray();
+
+  // Performs CSR write access.
+  util::Status WriteMaskArray(uint64 value);
+
+  // Validates that interrupt handler is |open|.
+  util::Status ValidateOpenState(bool open) const
+      SHARED_LOCKS_REQUIRED(mutex_);
+
+  // Handles single wire interrupt on platform devices.
+  void HandlePlatformSingleWireInterrupt();
+
+  // Handles 3-wire MSI interrupt.
+  void HandleMsi3WireInterrupt(int wire_id);
+
+  // Register access.
+  Registers* const registers_;
+
+  // CSR offsets.
+  const config::WireCsrOffsets wire_csr_offsets_;
+
+  // Number of wires.
+  const int num_wires_;
+
+  // Mutex that guards interrupts_ and open_ state.
+  mutable std::mutex mutex_;
+
+  // Tracks open state.
+  bool open_ GUARDED_BY(mutex_){false};
+
+  // Registered interrupts.
+  std::vector<Handler> interrupts_ GUARDED_BY(mutex_);
+};
+
+// Wire Interrupt handler implementation that polls the pending bit array.
+class PollingWireInterruptHandler : public WireInterruptHandler {
+ public:
+  // Default close to avoid name hiding.
+  using InterruptHandler::Close;
+
+  PollingWireInterruptHandler(Registers* registers,
+                              const config::WireCsrOffsets& wire_csr_offsets,
+                              std::function<void()> sleep);
+  ~PollingWireInterruptHandler() override = default;
+
+  // This class is neither copyable nor movable.
+  PollingWireInterruptHandler(const PollingWireInterruptHandler&) = delete;
+  PollingWireInterruptHandler& operator=(const PollingWireInterruptHandler&) =
+      delete;
+
+  util::Status Open() LOCKS_EXCLUDED(mutex_) override;
+  util::Status Close(bool in_error) LOCKS_EXCLUDED(mutex_) override;
+
+ private:
+  // Returns true if polling is enabled.
+  bool IsEnabled() const LOCKS_EXCLUDED(mutex_);
+
+  // Polls and dispatches interrupts.
+  void PollInterrupts();
+
+  // Mutex that guards enabled_ state.
+  mutable std::mutex mutex_;
+
+  // Tracks enabled state.
+  bool enabled_ GUARDED_BY(mutex_){false};
+
+  // Thread for polling interrupts.
+  std::thread thread_;
+
+  // Sleep function.
+ std::function sleep_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_INTERRUPT_WIRE_INTERRUPT_HANDLER_H_ diff --git a/driver/kernel/BUILD b/driver/kernel/BUILD new file mode 100644 index 0000000..d4f6680 --- /dev/null +++ b/driver/kernel/BUILD @@ -0,0 +1,121 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Kernel driver specific functionality. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "kernel_mmu_mapper", + srcs = ["kernel_mmu_mapper.cc"], + hdrs = ["kernel_mmu_mapper.h"], + deps = [ + ":linux_gasket_ioctl", + "//driver:hardware_structures", + "//driver/memory:dma_direction", + "//driver/memory:mmu_mapper", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "kernel_event", + srcs = ["kernel_event.cc"], + hdrs = ["kernel_event.h"], + deps = [ + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "kernel_event_handler", + srcs = ["kernel_event_handler.cc"], + hdrs = ["kernel_event_handler.h"], + deps = [ + ":kernel_event", + ":linux_gasket_ioctl", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "kernel_interrupt_handler", + srcs = ["kernel_interrupt_handler.cc"], + hdrs = ["kernel_interrupt_handler.h"], + deps = [ + "//driver/interrupt:interrupt_handler", + "//driver/kernel:kernel_event_handler", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "kernel_wire_interrupt_handler", + srcs = ["kernel_wire_interrupt_handler.cc"], + hdrs = ["kernel_wire_interrupt_handler.h"], + deps = [ + ":kernel_event_handler", + "//driver/config", + "//driver/interrupt:interrupt_handler", + "//driver/interrupt:wire_interrupt_handler", + "//driver/registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "kernel_coherent_allocator", + srcs = ["kernel_coherent_allocator.cc"], + hdrs = ["kernel_coherent_allocator.h"], + deps = [ + ":linux_gasket_ioctl", + "//driver:hardware_structures", + "//driver/mmio:coherent_allocator", + "//port", + ], +) + +cc_library( + name = "kernel_registers", + srcs = ["kernel_registers.cc"], + hdrs = ["kernel_registers.h"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "//driver/registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "linux_gasket_ioctl", + hdrs = ["linux_gasket_ioctl.h"], +) diff --git a/driver/kernel/kernel_coherent_allocator.cc b/driver/kernel/kernel_coherent_allocator.cc new file mode 100644 index 0000000..21e42e1 --- /dev/null +++ b/driver/kernel/kernel_coherent_allocator.cc @@ -0,0 +1,129 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not 
use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include // for PRI*64 +#include + +#include "driver/hardware_structures.h" +#include "driver/kernel/kernel_coherent_allocator.h" +#include "driver/kernel/linux_gasket_ioctl.h" +#include "port/cleanup.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/math_util.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelCoherentAllocator::KernelCoherentAllocator(const std::string &device_path, + int alignment_bytes, + size_t size_bytes) + : CoherentAllocator(alignment_bytes, size_bytes), + device_path_(device_path) {} + +util::StatusOr KernelCoherentAllocator::DoOpen(size_t size_bytes) { + if (fd_ != -1) { + return util::FailedPreconditionError("Device already open."); + } + + fd_ = open(device_path_.c_str(), O_RDWR); + if (fd_ < 0) { + return util::FailedPreconditionError( + StringPrintf("Device open failed : %d (%s)", fd_, strerror(errno))); + } + + auto fd_closer = MakeCleanup([this] { + close(fd_); + fd_ = -1; + }); + + // Enable the allocator and request the memory region + // Note: only one region is supported by kernel driver. + // The kernel coherent allocator returns zero'ed memory. + gasket_coherent_alloc_config_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + ioctl_buffer.page_table_index = 0; + ioctl_buffer.enable = 1; + ioctl_buffer.size = size_bytes; + if (ioctl(fd_, GASKET_IOCTL_CONFIG_COHERENT_ALLOCATOR, &ioctl_buffer) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Could not enable coherent allocator size %" PRIu64 ". : fd=%d (%s)", + ioctl_buffer.size, fd_, strerror(errno))); + } + + dma_address_ = ioctl_buffer.dma_address; + + // Map the memory range so as it can be accessed by user. + char *mem_base = + static_cast(mmap(nullptr, size_bytes, PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_LOCKED, fd_, dma_address_)); + if (mem_base == MAP_FAILED) { + // Release the memory block. + ioctl_buffer.page_table_index = 0; + ioctl_buffer.enable = 0; + ioctl_buffer.size = size_bytes; + if (ioctl(fd_, GASKET_IOCTL_CONFIG_COHERENT_ALLOCATOR, &ioctl_buffer) != + 0) { + VLOG(1) << StringPrintf("mmap_failed and couldn't free memory : %s.\n", + strerror(errno)); + } + return util::FailedPreconditionError( + StringPrintf("CoherentAllocator Could not mmap size %zu.", size_bytes)); + } + fd_closer.release(); + return mem_base; +} + +util::Status KernelCoherentAllocator::DoClose(char *mem_base, + size_t size_bytes) { + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + util::Status status; + + if (munmap(mem_base, size_bytes)) { + status.Update(util::FailedPreconditionError( + StringPrintf("Error unmapping coherent memory. %s", strerror(errno)))); + } + + // Release the memory block. 
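+  // This mirrors the enable call in DoOpen(): the same
+  // GASKET_IOCTL_CONFIG_COHERENT_ALLOCATOR ioctl is issued again, this time
+  // with enable = 0 and the DMA address recorded at open time, which asks the
+  // kernel driver to free the coherent region.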
+ gasket_coherent_alloc_config_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + ioctl_buffer.page_table_index = 0; + ioctl_buffer.enable = 0; + ioctl_buffer.dma_address = dma_address_; + ioctl_buffer.size = size_bytes; + if (ioctl(fd_, GASKET_IOCTL_CONFIG_COHERENT_ALLOCATOR, &ioctl_buffer) != 0) { + status.Update(util::FailedPreconditionError(StringPrintf( + "Could not disable coherent allocator size %" PRIu64 ". : %d (%s)", + ioctl_buffer.size, fd_, strerror(errno)))); + return status; + } + + close(fd_); + fd_ = -1; + dma_address_ = 0; + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_coherent_allocator.h b/driver/kernel/kernel_coherent_allocator.h new file mode 100644 index 0000000..e29b5e6 --- /dev/null +++ b/driver/kernel/kernel_coherent_allocator.h @@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_COHERENT_ALLOCATOR_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_COHERENT_ALLOCATOR_H_ + +#include +#include +#include +#include +#include + +#include "driver/mmio/coherent_allocator.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/statusor.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Functions to allocate coherent memory that is DMA-able by a Darwinn device. +class KernelCoherentAllocator : public CoherentAllocator { + public: + KernelCoherentAllocator(const std::string &device_path, int alignment_bytes, + size_t size_bytes); + ~KernelCoherentAllocator() = default; + + private: + // Implements Open. + util::StatusOr DoOpen(size_t size_bytes) override; + + // Implements close. + util::Status DoClose(char *mem_base, size_t size_bytes) override; + + // File descriptor of the opened device. + int fd_{-1}; + + // Device specific DMA address of the coherent memory block. + uint64 dma_address_{0}; + + // Device path. + const std::string device_path_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_COHERENT_ALLOCATOR_H_ diff --git a/driver/kernel/kernel_event.cc b/driver/kernel/kernel_event.cc new file mode 100644 index 0000000..87bc442 --- /dev/null +++ b/driver/kernel/kernel_event.cc @@ -0,0 +1,98 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/kernel/kernel_event.h" + +#include +#include +#include +#include +#include +#include + +#include // NOLINT + +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelEvent::KernelEvent(int event_fd, Handler handler) : event_fd_(event_fd) { + std::thread event_thread(&KernelEvent::Monitor, this, event_fd, + std::move(handler)); + thread_ = std::move(event_thread); +} + +KernelEvent::~KernelEvent() { + // Mark as disabled. + { + StdMutexLock lock(&mutex_); + enabled_ = false; + } + + // Write a fake event to force read() to return. + uint64 num_events = 1; + int result = write(event_fd_, &num_events, sizeof(num_events)); + if (result != sizeof(num_events)) { + LOG(WARNING) << StringPrintf("event_fd=%d. Fake event write failed (%d).", + event_fd_, result); + } + + // Wait for thread to exit. + thread_.join(); +} + +bool KernelEvent::IsEnabled() const { + StdMutexLock lock(&mutex_); + return enabled_; +} + +void KernelEvent::Monitor(int event_fd, const Handler& handler) { + VLOG(5) << StringPrintf("event_fd=%d. Monitor thread begin.", event_fd); + TRACE_START_THREAD("KernelEventHandlerMonitor"); + + while (IsEnabled()) { + // Wait for events (blocking). + uint64_t num_events = 0; + int result = read(event_fd, &num_events, sizeof(num_events)); + if (result != sizeof(num_events)) { + LOG(WARNING) << StringPrintf("event_fd=%d. Read failed (%d).", event_fd, + result); + break; + } + + VLOG(5) << StringPrintf( + "event_fd=%d. Monitor thread got num_events=%" PRId64 ".", event_fd, + num_events); + if (IsEnabled()) { + for (int i = 0; i < num_events; ++i) { + handler(); + } + } + } + + VLOG(5) << StringPrintf("event_fd=%d. Monitor thread exit.", event_fd); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_event.h b/driver/kernel/kernel_event.h new file mode 100644 index 0000000..3e96274 --- /dev/null +++ b/driver/kernel/kernel_event.h @@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_EVENT_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_EVENT_H_ + +#include // NOLINT +#include // NOLINT + +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Monitors events generated through eventfd. The eventfd file +// descriptor passed through the constructor must already be open +// and associated with an event source. Monitoring starts +// on instance creation and stops on destroy. +class KernelEvent { + public: + using Handler = std::function; + + KernelEvent(int event_fd, Handler handler); + ~KernelEvent(); + + // This class is neither copyable nor movable. 
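+  //
+  // Hypothetical usage sketch (in this change the real call site is
+  // KernelEventHandler; the fd must come from eventfd(2) and be registered
+  // with the kernel driver separately):
+  //
+  //   int fd = eventfd(0, EFD_CLOEXEC);
+  //   KernelEvent event(fd, [] { /* interrupt arrived */ });
+  //   // A monitor thread now blocks in read(fd) and runs the handler once
+  //   // per signalled event; destroying `event` wakes the thread and joins it.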
+ KernelEvent(const KernelEvent&) = delete; + KernelEvent& operator=(const KernelEvent&) = delete; + + private: + // Monitors eventfd_. Runs on thread_. + void Monitor(int event_fd, const Handler& handler); + + // Convenience function to read |enabled_| with locks held. + bool IsEnabled() const LOCKS_EXCLUDED(mutex_); + + // Event fd. + const int event_fd_; + + // Mutex that guards enabled_; + mutable std::mutex mutex_; + + // Enabled if true. + bool enabled_ GUARDED_BY(mutex_){true}; + + // Thread for monitoring interrupts. + std::thread thread_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_EVENT_H_ diff --git a/driver/kernel/kernel_event_handler.cc b/driver/kernel/kernel_event_handler.cc new file mode 100644 index 0000000..602b51e --- /dev/null +++ b/driver/kernel/kernel_event_handler.cc @@ -0,0 +1,119 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/kernel/kernel_event_handler.h" + +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include + +#include "driver/kernel/kernel_event.h" +#include "driver/kernel/linux_gasket_ioctl.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelEventHandler::KernelEventHandler(const std::string& device_path, + int num_events) + : device_path_(device_path), num_events_(num_events) { + event_fds_.resize(num_events_, -1); + events_.resize(num_events_); +} + +util::Status KernelEventHandler::Open() { + StdMutexLock lock(&mutex_); + if (fd_ != -1) { + return util::FailedPreconditionError("Device already open."); + } + + fd_ = open(device_path_.c_str(), O_RDWR); + if (fd_ < 0) { + return util::FailedPreconditionError( + StringPrintf("Device open failed : %d (%s)", fd_, strerror(errno))); + } + + for (int i = 0; i < num_events_; ++i) { + event_fds_[i] = eventfd(0, EFD_CLOEXEC); + events_[i].reset(); + } + + return util::Status(); // OK +} + +util::Status KernelEventHandler::Close() { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + for (int i = 0; i < num_events_; ++i) { + events_[i].reset(); + close(event_fds_[i]); + } + + close(fd_); + fd_ = -1; + + return util::Status(); // OK +} + +util::Status KernelEventHandler::SetEventFd(int event_fd, int event_id) const { + gasket_interrupt_eventfd interrupt; + interrupt.interrupt = event_id; + interrupt.event_fd = event_fd; + if (ioctl(fd_, GASKET_IOCTL_SET_EVENTFD, &interrupt) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Setting Event Fd Failed : %d (%s)", fd_, strerror(errno))); + } + + VLOG(5) << StringPrintf("Set event fd : event_id:%d -> event_fd:%d, ", + event_id, event_fd); + + return 
util::Status(); // OK +} + +util::Status KernelEventHandler::RegisterEvent(int event_id, + KernelEvent::Handler handler) { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + RETURN_IF_ERROR(SetEventFd(event_fds_[event_id], event_id)); + + // Enable events. + events_[event_id] = + gtl::MakeUnique(event_fds_[event_id], std::move(handler)); + + return util::Status(); // OK; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_event_handler.h b/driver/kernel/kernel_event_handler.h new file mode 100644 index 0000000..86e1479 --- /dev/null +++ b/driver/kernel/kernel_event_handler.h @@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_EVENT_HANDLER_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_EVENT_HANDLER_H_ + +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "driver/kernel/kernel_event.h" +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Implements a mechanism for processing kernel events. +class KernelEventHandler { + public: + KernelEventHandler(const std::string& device_path, int num_events); + ~KernelEventHandler() = default; + + util::Status Open() LOCKS_EXCLUDED(mutex_); + util::Status Close() LOCKS_EXCLUDED(mutex_); + + // Registers and enables the specified event. + util::Status RegisterEvent(int event_id, KernelEvent::Handler handler) + LOCKS_EXCLUDED(mutex_); + + private: + // Maps the specified event number with the specified id. + util::Status SetEventFd(int event_fd, int event_id) const + SHARED_LOCKS_REQUIRED(mutex_); + + // Device path. + const std::string device_path_; + + // Number of events. + const int num_events_; + + // Mutex that guards fd_, event_fd_, interrupts_; + mutable std::mutex mutex_; + + // File descriptor of the opened device. + int fd_ GUARDED_BY(mutex_){-1}; + + // Event FD list. + std::vector event_fds_ GUARDED_BY(mutex_); + + // Registered events. + std::vector> events_ GUARDED_BY(mutex_); +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_EVENT_HANDLER_H_ diff --git a/driver/kernel/kernel_interrupt_handler.cc b/driver/kernel/kernel_interrupt_handler.cc new file mode 100644 index 0000000..70843a2 --- /dev/null +++ b/driver/kernel/kernel_interrupt_handler.cc @@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/kernel/kernel_interrupt_handler.h" + +#include +#include + +#include "driver/interrupt/interrupt_handler.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelInterruptHandler::KernelInterruptHandler(const std::string& device_path) + : event_handler_(device_path, DW_INTERRUPT_COUNT) {} + +util::Status KernelInterruptHandler::Open() { return event_handler_.Open(); } + +util::Status KernelInterruptHandler::Close(bool in_error) { + return event_handler_.Close(); +} + +util::Status KernelInterruptHandler::Register(Interrupt interrupt, + Handler handler) { + return event_handler_.RegisterEvent(interrupt, std::move(handler)); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_interrupt_handler.h b/driver/kernel/kernel_interrupt_handler.h new file mode 100644 index 0000000..0a09dc5 --- /dev/null +++ b/driver/kernel/kernel_interrupt_handler.h @@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_INTERRUPT_HANDLER_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_INTERRUPT_HANDLER_H_ + +#include + +#include "driver/interrupt/interrupt_handler.h" +#include "driver/kernel/kernel_event_handler.h" +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Kernel implementation of the interrupt handler interface. +class KernelInterruptHandler : public InterruptHandler { + public: + // Default close to avoid name hiding. + using InterruptHandler::Close; + + explicit KernelInterruptHandler(const std::string& device_path); + ~KernelInterruptHandler() override = default; + + util::Status Open() override; + util::Status Close(bool in_error) override; + util::Status Register(Interrupt interrupt, Handler handler) override; + + private: + // Backing event handler. + KernelEventHandler event_handler_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_INTERRUPT_HANDLER_H_ diff --git a/driver/kernel/kernel_mmu_mapper.cc b/driver/kernel/kernel_mmu_mapper.cc new file mode 100644 index 0000000..9a8e4cb --- /dev/null +++ b/driver/kernel/kernel_mmu_mapper.cc @@ -0,0 +1,236 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/kernel/kernel_mmu_mapper.h" + +#include +#include +#include +#include +#include +#include + +#include // for PRI*64 +#include + +#include "driver/hardware_structures.h" +#include "driver/kernel/linux_gasket_ioctl.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelMmuMapper::KernelMmuMapper(const std::string &device_path) + : device_path_(device_path) {} + +util::Status KernelMmuMapper::Open( + int num_simple_page_table_entries_requested) { + StdMutexLock lock(&mutex_); + if (fd_ != -1) { + return util::FailedPreconditionError("Device already open."); + } + + fd_ = open(device_path_.c_str(), O_RDWR); + if (fd_ < 0) { + return util::FailedPreconditionError( + StringPrintf("Device open failed : %d (%s)", fd_, strerror(errno))); + } + + gasket_page_table_ioctl ioctl_buffer; + memset(&ioctl_buffer, 0, sizeof(ioctl_buffer)); + ioctl_buffer.page_table_index = 0; + ioctl_buffer.size = num_simple_page_table_entries_requested; + if (ioctl(fd_, GASKET_IOCTL_PARTITION_PAGE_TABLE, &ioctl_buffer) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Could not partition page table. : %d (%s)", fd_, strerror(errno))); + } + + return util::Status(); // OK +} + +util::Status KernelMmuMapper::Close() { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + close(fd_); + fd_ = -1; + + return util::Status(); // OK +} + +// Converts DmaDirection to +// gasket_page_table_ioctl_flags.flags.DMA_DIRECTION flag. +static uint32_t DirectionFlag(DmaDirection direction) { + switch (direction) { + case DmaDirection::kBidirectional: + return DMA_BIDIRECTIONAL; + case DmaDirection::kToDevice: + return DMA_TO_DEVICE; + case DmaDirection::kFromDevice: + return DMA_FROM_DEVICE; + } +} + +util::Status KernelMmuMapper::DoMap(const void *buffer, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) { + TRACE_SCOPE("KernelMmuMapper::DoMap"); + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + gasket_page_table_ioctl_flags buffer_to_map; + memset(&buffer_to_map, 0, sizeof(buffer_to_map)); + buffer_to_map.base.page_table_index = 0; + buffer_to_map.base.host_address = reinterpret_cast(buffer); + buffer_to_map.base.size = num_pages * kHostPageSize; + buffer_to_map.base.device_address = device_virtual_address; + buffer_to_map.flags = DirectionFlag(direction) + << GASKET_PT_FLAGS_DMA_DIRECTION_SHIFT; + + int ioctl_retval; + if (map_flags_supported_) { + ioctl_retval = ioctl(fd_, GASKET_IOCTL_MAP_BUFFER_FLAGS, &buffer_to_map); + if (ioctl_retval == -EPERM || ioctl_retval == -ENOTTY || + ioctl_retval == -EINVAL) { + VLOG(4) << StringPrintf("Failed to map buffer with flags, error %d", + ioctl_retval); + // This corresponds to an old kernel which doesn't yet support flags. 
+ // Set member variable to fallback to legacy IOCTL and try again. + map_flags_supported_ = false; + } + } + + if (!map_flags_supported_) { + ioctl_retval = ioctl(fd_, GASKET_IOCTL_MAP_BUFFER, &buffer_to_map.base); + } + + if (ioctl_retval != 0) { + return util::FailedPreconditionError( + StringPrintf("Could not map pages : %d (%s)", fd_, strerror(errno))); + } + + if (map_flags_supported_) { + VLOG(4) << StringPrintf("MmuMapper#Map() : %016" PRIx64 " -> %016" PRIx64 + " (%d pages) flags=%08" PRIx32 ".", + buffer_to_map.base.host_address, + buffer_to_map.base.device_address, num_pages, + buffer_to_map.flags); + } else { + VLOG(4) << StringPrintf("MmuMapper#Map() : %016" PRIx64 " -> %016" PRIx64 + " (%d pages).", + buffer_to_map.base.host_address, + buffer_to_map.base.device_address, num_pages); + } + + return util::Status(); // OK +} + +util::Status KernelMmuMapper::DoUnmap(const void *buffer, int num_pages, + uint64 device_virtual_address) { + TRACE_SCOPE("KernelMmuMapper::DoUnmap"); + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + gasket_page_table_ioctl buffer_to_map; + memset(&buffer_to_map, 0, sizeof(buffer_to_map)); + buffer_to_map.page_table_index = 0; + buffer_to_map.host_address = reinterpret_cast(buffer); + buffer_to_map.size = num_pages * kHostPageSize; + buffer_to_map.device_address = device_virtual_address; + if (ioctl(fd_, GASKET_IOCTL_UNMAP_BUFFER, &buffer_to_map) != 0) { + return util::FailedPreconditionError( + StringPrintf("Could not unmap pages : %d (%s)", fd_, strerror(errno))); + } + + VLOG(4) << StringPrintf( + "MmuMaper#Unmap() : %016" PRIx64 " -> %016" PRIx64 " (%d pages).", + buffer_to_map.host_address, buffer_to_map.device_address, num_pages); + + return util::Status(); // OK +} + +util::Status KernelMmuMapper::DoMap(int fd, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) { + TRACE_SCOPE("KernelMmuMapper::DoMap"); + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + gasket_page_table_ioctl_dmabuf buffer_to_map = {0}; + buffer_to_map.map = 1; + buffer_to_map.page_table_index = 0; + buffer_to_map.num_pages = num_pages; + buffer_to_map.dmabuf_fd = fd; + buffer_to_map.device_address = device_virtual_address; + buffer_to_map.flags = DirectionFlag(direction) + << GASKET_PT_FLAGS_DMA_DIRECTION_SHIFT; + + int ioctl_retval = ioctl(fd_, GASKET_IOCTL_MAP_DMABUF, &buffer_to_map); + if (ioctl_retval != 0) { + return util::FailedPreconditionError( + StringPrintf("Could not map pages : %d (%s)", fd_, strerror(errno))); + } + + VLOG(4) << StringPrintf("MmuMapper#Map() : fd %d -> %016" PRIx64 + " (%d pages) flags=%08" PRIx32 ".", + buffer_to_map.dmabuf_fd, buffer_to_map.device_address, + num_pages, buffer_to_map.flags); + + return util::OkStatus(); +} + +util::Status KernelMmuMapper::DoUnmap(int fd, int num_pages, + uint64 device_virtual_address) { + TRACE_SCOPE("KernelMmuMapper::DoUnmap"); + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + gasket_page_table_ioctl_dmabuf buffer_to_unmap = {0}; + buffer_to_unmap.map = 0; + buffer_to_unmap.page_table_index = 0; + buffer_to_unmap.num_pages = num_pages; + buffer_to_unmap.dmabuf_fd = fd; + buffer_to_unmap.device_address = device_virtual_address; + if (ioctl(fd_, GASKET_IOCTL_MAP_DMABUF, &buffer_to_unmap) != 0) { + return util::FailedPreconditionError( + StringPrintf("Could not unmap pages : %d (%s)", fd_, 
strerror(errno))); + } + + VLOG(4) << StringPrintf( + "MmuMaper#Unmap() : fd %d -> %016" PRIx64 " (%d pages).", + buffer_to_unmap.dmabuf_fd, buffer_to_unmap.device_address, num_pages); + + return util::OkStatus(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_mmu_mapper.h b/driver/kernel/kernel_mmu_mapper.h new file mode 100644 index 0000000..02d790a --- /dev/null +++ b/driver/kernel/kernel_mmu_mapper.h @@ -0,0 +1,93 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_MMU_MAPPER_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_MMU_MAPPER_H_ + +#include +#include // NOLINT + +#include "driver/memory/dma_direction.h" +#include "driver/memory/mmu_mapper.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Kernel implementation of the MMU mapper interface. +class KernelMmuMapper : public MmuMapper { + public: + explicit KernelMmuMapper(const std::string &device_path); + ~KernelMmuMapper() override = default; + + // Overrides from mmu_mapper.h + util::Status Open(int num_simple_page_table_entries_requested) override; + util::Status Close() override; + + protected: + util::Status DoMap(const void *buffer, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) override; + util::Status DoUnmap(const void *buffer, int num_pages, + uint64 device_virtual_address) override; + util::Status DoMap(int fd, int num_pages, uint64 device_virtual_address, + DmaDirection direction) override; + util::Status DoUnmap(int fd, int num_pages, + uint64 device_virtual_address) override; + + // Calls ioctl on the device file descriptor owned by this instance. + // Forwards the parameters to the ioctl too; returns -1 on closed device and + // the return value of ioctl otherwise. + template + int DoIoctl(Params &&... params) { + // TODO : At this moment this mutex is there to guard uses of fd, + // but if later we find there's a lot of concurrent threaded map/unmap + // activity we could consider trying to allow them to run in parallel. If + // this macro is to be called on ioctls that aren't necessarily + // mutually-exclusive / can run in parallel (at the runtime level) + // then may want to make appropriate locking the caller's responsibility. + + StdMutexLock lock(&mutex_); + if (fd_ != -1) { + return ioctl(fd_, std::forward(params)...); + } else { + VLOG(4) << "Invalid file descriptor."; + return -1; + } + } + + private: + // Device path. + const std::string device_path_; + + // File descriptor of the opened device. + int fd_ GUARDED_BY(mutex_){-1}; + + // Mutex that guards fd_; + mutable std::mutex mutex_; + + // Indicates whether the kernel driver supports GASKET_IOCTL_MAP_BUFFER_FLAGS. 
+ bool map_flags_supported_ GUARDED_BY(mutex_){true}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_MMU_MAPPER_H_ diff --git a/driver/kernel/kernel_registers.cc b/driver/kernel/kernel_registers.cc new file mode 100644 index 0000000..7baf9e0 --- /dev/null +++ b/driver/kernel/kernel_registers.cc @@ -0,0 +1,273 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/kernel/kernel_registers.h" + +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_format.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelRegisters::KernelRegisters(const std::string& device_path, + const std::vector& mmap_region, + bool read_only) + : device_path_(device_path), read_only_(read_only) { + for (const auto& region : mmap_region) { + mmap_region_.push_back({region.offset, region.size, nullptr}); + } +} + +KernelRegisters::KernelRegisters(const std::string& device_path, + uint64 mmap_offset, uint64 mmap_size, + bool read_only) + : KernelRegisters(device_path, {{mmap_offset, mmap_size}}, read_only) {} + +KernelRegisters::~KernelRegisters() { + for (auto& region : mmap_region_) { + if (region.registers != nullptr) { + const int ret = munmap(region.registers, region.size); + if (ret != 0) { + LOG(ERROR) << "Error unmapping registers: " << strerror(errno); + } + region.registers = nullptr; + } + } + + if (fd_ != -1) { + LOG(WARNING) + << "Destroying KernelRegisters - Close() had not yet been called!"; + util::Status status = Close(); + if (!status.ok()) { + LOG(ERROR) << status; + } + } +} + +util::Status KernelRegisters::Open() { + StdMutexLock lock(&mutex_); + if (fd_ != -1) { + return util::FailedPreconditionError("Device already open."); + } + + VLOG(1) << StringPrintf("Opening %s. 
read_only=%d", device_path_.c_str(), + read_only_); + int mode = O_RDWR; + if (read_only_) { + mode = O_RDONLY; + } + + fd_ = open(device_path_.c_str(), mode); + if (fd_ < 0) { + return util::FailedPreconditionError( + StringPrintf("Device open failed : %d (%s)", fd_, strerror(errno))); + } + + int protections = PROT_READ | PROT_WRITE; + if (read_only_) { + protections = PROT_READ; + } + + for (auto& region : mmap_region_) { + VLOG(1) << StringPrintf("mmap_offset=0x%016llx, mmap_size=%lld", + static_cast(region.offset), + static_cast(region.size)); + + region.registers = + static_cast(mmap(nullptr, region.size, protections, + MAP_SHARED | MAP_LOCKED, fd_, region.offset)); + if (region.registers == MAP_FAILED) { + close(fd_); + fd_ = -1; + region.registers = nullptr; + + return util::FailedPreconditionError( + StringPrintf("Could not mmap: %s.", strerror(errno))); + } + VLOG(3) << "Got map addr at 0x" << std::hex << region.registers; + } + + return util::Status(); // OK +} + +util::Status KernelRegisters::Close() { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + + for (auto& region : mmap_region_) { + if (region.registers != nullptr) { + VLOG(1) << StringPrintf( + "Closing %s. mmap_offset=0x%016llx, mmap_size=%lld, read_only=%d", + device_path_.c_str(), + static_cast(region.offset), // NOLINT(runtime/int) + static_cast(region.size), // NOLINT(runtime/int) + read_only_); + const int ret = munmap(region.registers, region.size); + if (ret != 0) { + LOG(ERROR) << "Error unmapping registers: " << strerror(errno); + } + region.registers = nullptr; + } + } + + close(fd_); + fd_ = -1; + + return util::Status(); // OK +} + +util::StatusOr KernelRegisters::LockAndGetMappedOffset( + uint64 offset, int alignment) const { + StdMutexLock lock(&mutex_); + return GetMappedOffset(offset, alignment); +} + +inline util::StatusOr KernelRegisters::GetMappedOffset( + uint64 offset, int alignment) const { + const size_t end_of_region = offset + alignment; + + if (end_of_region < offset) { + return util::OutOfRangeError( + StringPrintf("Offset (0x%016llx) + size_bytes is larger than 64-bit", + static_cast(offset))); + } + + for (const auto& region : mmap_region_) { + if ((offset >= region.offset) && + ((end_of_region - region.offset) <= region.size)) { + if (region.registers != nullptr) { + return reinterpret_cast(region.registers) + offset - + region.offset; + } else { + return util::InternalError("Region not mapped yet"); + } + } + } + + return util::OutOfRangeError(absl::StrFormat( + "Offset (0x%016llx) is not covered by any region", offset)); +} + +util::Status KernelRegisters::Write(uint64 offset, uint64 value) { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + if (read_only_) { + return util::FailedPreconditionError("Read only, cannot write."); + } + if (offset % sizeof(uint64) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Offset (0x%016llx) not aligned to 8B", + static_cast(offset))); // NOLINT(runtime/int) + } + + ASSIGN_OR_RETURN(auto mmap_register, GetMappedOffset(offset, sizeof(uint64))); + *reinterpret_cast(mmap_register) = value; + VLOG(5) << StringPrintf( + "Write: offset = 0x%016llx, value = 0x%016llx", + static_cast(offset), // NOLINT(runtime/int) + static_cast(value)); // NOLINT(runtime/int) + + return util::Status(); // OK +} + +util::StatusOr KernelRegisters::Read(uint64 offset) { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return 
util::FailedPreconditionError("Device not open."); + } + if (offset % sizeof(uint64) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Offset (0x%016llx) not aligned to 8B", + static_cast(offset))); // NOLINT(runtime/int) + } + + ASSIGN_OR_RETURN(auto mmap_register, GetMappedOffset(offset, sizeof(uint64))); + uint64 value = *reinterpret_cast(mmap_register); + VLOG(5) << StringPrintf( + "Read: offset = 0x%016llx, value: = 0x%016llx", + static_cast(offset), // NOLINT(runtime/int) + static_cast(value)); // NOLINT(runtime/int) + + return value; +} + +util::Status KernelRegisters::Write32(uint64 offset, uint32 value) { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + if (read_only_) { + return util::FailedPreconditionError("Read only, cannot write."); + } + if (offset % sizeof(uint32) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Offset (0x%016llx) not aligned to 8B", + static_cast(offset))); // NOLINT(runtime/int) + } + + ASSIGN_OR_RETURN(auto mmap_register, GetMappedOffset(offset, sizeof(uint32))); + *reinterpret_cast(mmap_register) = value; + VLOG(5) << StringPrintf( + "Write: offset = 0x%016llx, value = 0x%08x", + static_cast(offset), // NOLINT(runtime/int) + value); + + return util::Status(); // OK +} + +util::StatusOr KernelRegisters::Read32(uint64 offset) { + StdMutexLock lock(&mutex_); + if (fd_ == -1) { + return util::FailedPreconditionError("Device not open."); + } + if (offset % sizeof(uint32) != 0) { + return util::FailedPreconditionError(StringPrintf( + "Offset (0x%016llx) not aligned to 8B", + static_cast(offset))); // NOLINT(runtime/int) + } + + ASSIGN_OR_RETURN(auto mmap_register, GetMappedOffset(offset, sizeof(uint32))); + uint32 value = *reinterpret_cast(mmap_register); + VLOG(5) << StringPrintf( + "Read: offset = 0x%016llx, value: = 0x%08x", + static_cast(offset), // NOLINT(runtime/int) + value); + + return value; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_registers.h b/driver/kernel/kernel_registers.h new file mode 100644 index 0000000..4e95ea7 --- /dev/null +++ b/driver/kernel/kernel_registers.h @@ -0,0 +1,98 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_REGISTERS_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_REGISTERS_H_ + +#include // NOLINT +#include +#include + +#include "driver/registers/registers.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Kernel implementation of the register interface. 
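+//
+// CSR space from the device node is mapped into the process with mmap(), and
+// Read()/Write() then become aligned loads and stores on that mapping.
+//
+// Hypothetical usage sketch (device path, region size, and offset below are
+// placeholders, not values taken from this change):
+//
+//   KernelRegisters regs("/dev/apex_0", /*mmap_offset=*/0,
+//                        /*mmap_size=*/0x2000, /*read_only=*/false);
+//   CHECK_OK(regs.Open());
+//   ASSIGN_OR_RETURN(uint64 value, regs.Read(/*offset=*/0x48));
+//   CHECK_OK(regs.Close());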
+class KernelRegisters : public Registers { + public: + struct MmapRegion { + uint64 offset; + uint64 size; + }; + + KernelRegisters(const std::string& device_path, + const std::vector& mmap_region, bool read_only); + + KernelRegisters(const std::string& device_path, uint64 mmap_offset, + uint64 mmap_size, bool read_only); + + ~KernelRegisters() override; + + // Overrides from registers.h + util::Status Open() LOCKS_EXCLUDED(mutex_) override; + util::Status Close() LOCKS_EXCLUDED(mutex_) override; + util::Status Write(uint64 offset, uint64 value) + LOCKS_EXCLUDED(mutex_) override; + util::StatusOr Read(uint64 offset) LOCKS_EXCLUDED(mutex_) override; + util::Status Write32(uint64 offset, uint32 value) + LOCKS_EXCLUDED(mutex_) override; + util::StatusOr Read32(uint64 offset) LOCKS_EXCLUDED(mutex_) override; + + protected: + struct MappedRegisterRegion { + uint64 offset; + uint64 size; + uint64* registers; + }; + + // Acquires the lock and maps CSR offset. + util::StatusOr LockAndGetMappedOffset(uint64 offset, + int alignment) const + LOCKS_EXCLUDED(mutex_); + + // Returns the reference to the mapped regions. + std::vector& GetMmapRegion() { return mmap_region_; } + + private: + // Maps CSR offset to virtual address without acquiring the lock. + util::StatusOr GetMappedOffset(uint64 offset, int alignment) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Device path. + const std::string device_path_; + + // mmap() region. + std::vector mmap_region_ GUARDED_BY(mutex_); + + // true, if read only. false otherwise. + const bool read_only_; + + // File descriptor of the opened device. + int fd_ GUARDED_BY(mutex_){-1}; + + // Mutex that guards fd_; + mutable std::mutex mutex_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_REGISTERS_H_ diff --git a/driver/kernel/kernel_wire_interrupt_handler.cc b/driver/kernel/kernel_wire_interrupt_handler.cc new file mode 100644 index 0000000..cf40370 --- /dev/null +++ b/driver/kernel/kernel_wire_interrupt_handler.cc @@ -0,0 +1,78 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/kernel/kernel_wire_interrupt_handler.h" + +#include + +#include "driver/config/wire_csr_offsets.h" +#include "driver/interrupt/wire_interrupt_handler.h" +#include "driver/registers/registers.h" +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +KernelWireInterruptHandler::KernelWireInterruptHandler( + Registers* registers, const config::WireCsrOffsets& wire_csr_offsets, + const std::string& device_path, int num_wires) + : wire_handler_(registers, wire_csr_offsets, num_wires), + event_handler_(device_path, num_wires), + num_wires_(num_wires) {} + +util::Status KernelWireInterruptHandler::Open() { + RETURN_IF_ERROR(wire_handler_.Open()); + auto wire_handler_closer = MakeCleanup( + [this]() NO_THREAD_SAFETY_ANALYSIS { CHECK_OK(wire_handler_.Close()); }); + + RETURN_IF_ERROR(event_handler_.Open()); + auto event_handler_closer = MakeCleanup( + [this]() NO_THREAD_SAFETY_ANALYSIS { CHECK_OK(event_handler_.Close()); }); + + for (int wire = 0; wire < num_wires_; ++wire) { + RETURN_IF_ERROR(event_handler_.RegisterEvent(wire, [this, wire]() { + wire_handler_.InvokeAllPendingInterrupts(wire); + })); + } + + // All good. Release cleanup functions. + wire_handler_closer.release(); + event_handler_closer.release(); + + return util::Status(); // OK +} + +util::Status KernelWireInterruptHandler::Close(bool in_error) { + util::Status status; + status.Update(event_handler_.Close()); + status.Update(wire_handler_.Close()); + return status; +} + +util::Status KernelWireInterruptHandler::Register(Interrupt interrupt, + Handler handler) { + return wire_handler_.Register(interrupt, std::move(handler)); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/kernel/kernel_wire_interrupt_handler.h b/driver/kernel/kernel_wire_interrupt_handler.h new file mode 100644 index 0000000..718efc1 --- /dev/null +++ b/driver/kernel/kernel_wire_interrupt_handler.h @@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_KERNEL_KERNEL_WIRE_INTERRUPT_HANDLER_H_ +#define DARWINN_DRIVER_KERNEL_KERNEL_WIRE_INTERRUPT_HANDLER_H_ + +#include // NOLINT +#include + +#include "driver/config/wire_csr_offsets.h" +#include "driver/interrupt/interrupt_handler.h" +#include "driver/interrupt/wire_interrupt_handler.h" +#include "driver/kernel/kernel_event_handler.h" +#include "driver/registers/registers.h" +#include "port/status.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Wire Interrupt handler implementation that reads and processes the pending +// bit array on a single wire interrupt in userspace. 
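+//
+// It composes the two pieces defined earlier in this change: a
+// KernelEventHandler waits on one eventfd per wire, and when a wire fires the
+// backing WireInterruptHandler reads the pending bit array over CSRs and
+// dispatches the handlers registered through Register().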
+class KernelWireInterruptHandler : public InterruptHandler { + public: + // Default close to avoid name hiding. + using InterruptHandler::Close; + + KernelWireInterruptHandler(Registers* registers, + const config::WireCsrOffsets& wire_csr_offsets, + const std::string& device_path, int num_wires); + ~KernelWireInterruptHandler() override = default; + + // This class is neither copyable nor movable. + KernelWireInterruptHandler(const KernelWireInterruptHandler&) = delete; + KernelWireInterruptHandler& operator=(const KernelWireInterruptHandler&) = + delete; + + util::Status Open() override; + util::Status Close(bool in_error) override; + util::Status Register(Interrupt interrupt, Handler handler) override; + + private: + // Backing wire interrupt handler. + WireInterruptHandler wire_handler_; + + // KernelEventHandler + KernelEventHandler event_handler_; + + // Number of wires. + const int num_wires_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_KERNEL_KERNEL_WIRE_INTERRUPT_HANDLER_H_ diff --git a/driver/kernel/linux_gasket_ioctl.h b/driver/kernel/linux_gasket_ioctl.h new file mode 100644 index 0000000..2ce6f19 --- /dev/null +++ b/driver/kernel/linux_gasket_ioctl.h @@ -0,0 +1,202 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* Common Gasket device kernel and user space declarations. */ +#ifndef __LINUX_GASKET_IOCTL_H__ +#define __LINUX_GASKET_IOCTL_H__ + +#include +#include + +#ifndef __KERNEL__ +#include +#endif + +/* ioctl structure declarations */ + +/* Ioctl structures are padded to a multiple of 64 bits */ +/* and padded to put 64 bit values on 64 bit boundaries. */ +/* Unsigned 64 bit integers are used to hold pointers. */ +/* This helps compatibility between 32 and 64 bits. */ + +/* + * Common structure for ioctls associating an eventfd with a device interrupt, + * when using the Gasket interrupt module. + */ +struct gasket_interrupt_eventfd { + uint64_t interrupt; + uint64_t event_fd; +}; + +/* + * Common structure for ioctls mapping and unmapping buffers when using the + * Gasket page_table module. + */ +struct gasket_page_table_ioctl { + uint64_t page_table_index; + uint64_t size; + uint64_t host_address; + uint64_t device_address; +}; + +/* + * Definitions for gasket_page_table_ioctl_flags.flags bitfield. + */ +#define GASKET_PT_FLAGS_DMA_DIRECTION_SHIFT 1 +#define GASKET_PT_FLAGS_DMA_DIRECTION_WIDTH 2 + +/* + * Value for gasket_page_table_ioctl_flags.flags.DMA_DIRECTION. + * Mirrors kernel dma_data_direction definition in dma-direction.h. + */ +enum dma_data_direction { + DMA_BIDIRECTIONAL = 0, + DMA_TO_DEVICE = 1, + DMA_FROM_DEVICE = 2, + DMA_NONE = 3, +}; + +/* + * Structure for ioctl mapping buffers with flags when using the Gasket + * page_table module. + */ +struct gasket_page_table_ioctl_flags { + struct gasket_page_table_ioctl base; + /* + * Flags indicating status and attribute requests from the host. 
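+ *
+ * For example, a request for a host-to-device mapping (as built by
+ * KernelMmuMapper::DoMap() earlier in this change) sets:
+ *
+ *   flags = DMA_TO_DEVICE << GASKET_PT_FLAGS_DMA_DIRECTION_SHIFT;
+ *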
+ * NOTE: Set RESERVED bits to 0 to ensure backwards compatibility. + * + * Bitfields: + * [0] - RESERVED + * [2:1] - DMA_DIRECTION: dma_data_direction requested by host + * [31:3] - RESERVED + */ + uint32_t flags; +}; + +/* + * Common structure for ioctls mapping and unmapping buffers when using the + * Gasket page_table module. + * dma_address: phys addr start of coherent memory, allocated by kernel + */ +struct gasket_coherent_alloc_config_ioctl { + uint64_t page_table_index; + uint64_t enable; + uint64_t size; + uint64_t dma_address; +}; + +/* + * Common structure for ioctls mapping and unmapping dma-bufs when using the + * Gasket page_table module. + * map: boolean, non-zero to map, 0 to unmap. + * flags: see gasket_page_table_ioctl_flags.flags. + */ +struct gasket_page_table_ioctl_dmabuf { + uint64_t page_table_index; + uint64_t device_address; + int dmabuf_fd; + uint32_t num_pages; + uint32_t map; + uint32_t flags; +}; + +/* Base number for all Gasket-common IOCTLs */ +#define GASKET_IOCTL_BASE 0xDC + +/* Reset the device. */ +// NOLINTNEXTLINE +#define GASKET_IOCTL_RESET _IO(GASKET_IOCTL_BASE, 0) + +/* Associate the specified [event]fd with the specified interrupt. */ +#define GASKET_IOCTL_SET_EVENTFD \ + _IOW(GASKET_IOCTL_BASE, 1, struct gasket_interrupt_eventfd) + +/* + * Clears any eventfd associated with the specified interrupt. The (ulong) + * argument is the interrupt number to clear. + */ +// NOLINTNEXTLINE +#define GASKET_IOCTL_CLEAR_EVENTFD _IOW(GASKET_IOCTL_BASE, 2, unsigned long) + +/* + * [Loopbacks only] Requests that the loopback device send the specified + * interrupt to the host. The (ulong) argument is the number of the interrupt to + * send. + */ +#define GASKET_IOCTL_LOOPBACK_INTERRUPT \ + _IOW(GASKET_IOCTL_BASE, 3, unsigned long) // NOLINT + +/* Queries the kernel for the number of page tables supported by the device. */ +#define GASKET_IOCTL_NUMBER_PAGE_TABLES _IOR(GASKET_IOCTL_BASE, 4, uint64_t) + +/* + * Queries the kernel for the maximum size of the page table. Only the size and + * page_table_index fields are used from the struct gasket_page_table_ioctl. + */ +#define GASKET_IOCTL_PAGE_TABLE_SIZE \ + _IOWR(GASKET_IOCTL_BASE, 5, struct gasket_page_table_ioctl) + +/* + * Queries the kernel for the current simple page table size. Only the size and + * page_table_index fields are used from the struct gasket_page_table_ioctl. + */ +#define GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE \ + _IOWR(GASKET_IOCTL_BASE, 6, struct gasket_page_table_ioctl) + +/* + * Tells the kernel to change the split between the number of simple and + * extended entries in the given page table. Only the size and page_table_index + * fields are used from the struct gasket_page_table_ioctl. + */ +#define GASKET_IOCTL_PARTITION_PAGE_TABLE \ + _IOW(GASKET_IOCTL_BASE, 7, struct gasket_page_table_ioctl) + +/* + * Tells the kernel to map size bytes at host_address to device_address in + * page_table_index page table. + */ +#define GASKET_IOCTL_MAP_BUFFER \ + _IOW(GASKET_IOCTL_BASE, 8, struct gasket_page_table_ioctl) + +/* + * Tells the kernel to unmap size bytes at host_address from device_address in + * page_table_index page table. + */ +#define GASKET_IOCTL_UNMAP_BUFFER \ + _IOW(GASKET_IOCTL_BASE, 9, struct gasket_page_table_ioctl) + +/* Clear the interrupt counts stored for this device. */ +#define GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS _IO(GASKET_IOCTL_BASE, 10) + +/* Enable/Disable and configure the coherent allocator. 
*/ +#define GASKET_IOCTL_CONFIG_COHERENT_ALLOCATOR \ + _IOWR(GASKET_IOCTL_BASE, 11, struct gasket_coherent_alloc_config_ioctl) + +/* + * Tells the kernel to map size bytes at host_address to device_address in + * page_table_index page table. Passes flags to indicate additional attribute + * requests for the mapped memory. + */ +#define GASKET_IOCTL_MAP_BUFFER_FLAGS \ + _IOW(GASKET_IOCTL_BASE, 12, struct gasket_page_table_ioctl_flags) + +/* + * Tells the kernel to map/unmap dma-buf with fd to device_address in + * page_table_index page table. + */ +#define GASKET_IOCTL_MAP_DMABUF \ + _IOW(GASKET_IOCTL_BASE, 13, struct gasket_page_table_ioctl_dmabuf) + +#endif /* __LINUX_GASKET_IOCTL_H__ */ diff --git a/driver/libdarwinn_driver.lds b/driver/libdarwinn_driver.lds new file mode 100644 index 0000000..68d3d26 --- /dev/null +++ b/driver/libdarwinn_driver.lds @@ -0,0 +1,11 @@ +VER_1.0 { + global: + extern "C++" { + platforms::darwinn::api::*; + platforms::darwinn::internal::*; + platforms::darwinn::util::*; + platforms::darwinn::Buffer::*; + }; + local: + *; +}; diff --git a/driver/memory/BUILD b/driver/memory/BUILD new file mode 100644 index 0000000..603852d --- /dev/null +++ b/driver/memory/BUILD @@ -0,0 +1,188 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Memory management related functionality. 
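// Illustrative only: the GASKET_IOCTL_* requests declared in
// driver/kernel/linux_gasket_ioctl.h above are issued against the Gasket
// character device from userspace. This sketch is not part of the runtime;
// the device node path, interrupt number, device address, and helper name are
// assumptions made for the example, and it presumes the declarations from
// that header are in scope.
#include <fcntl.h>
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <unistd.h>

#include <cstdint>
#include <cstdlib>

int GasketIoctlSketch() {
  const int device = open("/dev/apex_0", O_RDWR);  // hypothetical device node
  if (device < 0) return -1;

  // Route interrupt 0 to an eventfd that a userspace thread can wait on.
  const int event = eventfd(0, 0);
  gasket_interrupt_eventfd interrupt_request;
  interrupt_request.interrupt = 0;
  interrupt_request.event_fd = static_cast<uint64_t>(event);
  if (ioctl(device, GASKET_IOCTL_SET_EVENTFD, &interrupt_request) != 0) {
    // Error handling omitted in this sketch.
  }

  // Map one host page into page table 0 at an arbitrary device address.
  void* host = std::aligned_alloc(4096, 4096);
  gasket_page_table_ioctl map_request;
  map_request.page_table_index = 0;
  map_request.size = 4096;
  map_request.host_address = reinterpret_cast<uint64_t>(host);
  map_request.device_address = 0x1000;
  if (ioctl(device, GASKET_IOCTL_MAP_BUFFER, &map_request) == 0) {
    // DMA to/from device address 0x1000 would happen here.
    ioctl(device, GASKET_IOCTL_UNMAP_BUFFER, &map_request);
  }

  std::free(host);
  close(event);
  close(device);
  return 0;
}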
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "address_utilities", + hdrs = ["address_utilities.h"], + deps = [ + "//driver:hardware_structures", + "//port", + ], +) + +cc_library( + name = "mmu_mapper", + srcs = ["mmu_mapper.cc"], + hdrs = ["mmu_mapper.h"], + deps = [ + ":address_utilities", + "//api:buffer", + "//driver:device_buffer", + "//driver/memory:dma_direction", + "//port", + "//port:tracing", + ], +) + +cc_library( + name = "address_space", + hdrs = ["address_space.h"], + deps = [ + ":dma_direction", + "//api:buffer", + "//driver:device_buffer", + "//port", + ], +) + +cc_library( + name = "mmio_address_space", + srcs = ["mmio_address_space.cc"], + hdrs = ["mmio_address_space.h"], + deps = [ + ":address_space", + ":address_utilities", + ":dma_direction", + ":mmu_mapper", + "//api:buffer", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "buddy_address_space", + srcs = ["buddy_address_space.cc"], + hdrs = ["buddy_address_space.h"], + deps = [ + ":address_utilities", + ":buddy_allocator", + ":dma_direction", + ":mmio_address_space", + ":mmu_mapper", + "//api:buffer", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) + +cc_library( + name = "address_space_allocator", + hdrs = ["address_space_allocator.h"], + deps = ["//port"], +) + +cc_library( + name = "buddy_allocator", + srcs = ["buddy_allocator.cc"], + hdrs = ["buddy_allocator.h"], + deps = [ + ":address_space_allocator", + ":address_utilities", + "@com_google_absl//absl/strings:str_format", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "fake_mmu_mapper", + srcs = ["fake_mmu_mapper.cc"], + hdrs = ["fake_mmu_mapper.h"], + deps = [ + ":address_utilities", + ":dma_direction", + ":mmu_mapper", + "//driver:hardware_structures", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "dma_direction", + hdrs = ["dma_direction.h"], + deps = [ + ], +) + +cc_library( + name = "nop_address_space", + srcs = ["nop_address_space.cc"], + hdrs = ["nop_address_space.h"], + deps = [ + ":address_space", + ":dma_direction", + "//api:buffer", + "//driver:device_buffer", + "//port", + ], +) + +cc_library( + name = "dual_address_space", + srcs = ["dual_address_space.cc"], + hdrs = ["dual_address_space.h"], + deps = [ + ":address_space", + ":buddy_address_space", + ":mmu_mapper", + "//driver:hardware_structures", + "//driver/config", + "//port", + ], +) + +cc_library( + name = "dram_allocator", + hdrs = ["dram_allocator.h"], + deps = [ + "//api:dram_buffer", + "//port", + ], +) + +cc_library( + name = "null_dram_allocator", + hdrs = ["null_dram_allocator.h"], + deps = [ + ":dram_allocator", + "//api:dram_buffer", + "//port", + ], +) + +cc_library( + name = "fake_dram_allocator", + srcs = ["fake_dram_allocator.cc"], + hdrs = ["fake_dram_allocator.h"], + deps = [ + ":dram_allocator", + "//api:dram_buffer", + "//port", + ], +) diff --git a/driver/memory/address_space.h b/driver/memory/address_space.h new file mode 100644 index 0000000..cffffe5 --- /dev/null +++ b/driver/memory/address_space.h @@ -0,0 +1,88 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_H_ +#define DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_H_ + +#include +#include + +#include "api/buffer.h" +#include "driver/device_buffer.h" +#include "driver/memory/dma_direction.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// A hint that the implementation should use a particular type of address +// space mapping, for systems that have multiple mapping types. +enum class MappingTypeHint { + // No preference. Most mappings should be of this type. + kAny, + + // Use simple address space mappings, if the hardware is capable. + kSimple, + + // Use extended address space mappings, if the hardware is capable. + kExtended, +}; + +// An interface for managing a DarwiNN virtual address space segment. +class AddressSpace { + public: + AddressSpace() = default; + + // This class is neither copyable nor movable. + AddressSpace(const AddressSpace&) = delete; + AddressSpace& operator=(const AddressSpace&) = delete; + + virtual ~AddressSpace() = default; + + // Maps the buffer to the device buffer. Returns the mapped device + // buffer on success. + util::StatusOr MapMemory(const Buffer& buffer) { + return MapMemory(buffer, DmaDirection::kBidirectional, + MappingTypeHint::kAny); + } + + // Same as above but with a hint indicating the buffer transfer direction and + // a hint indicating whether to use simple or extended mappings. + virtual util::StatusOr MapMemory( + const Buffer& buffer, DmaDirection direction, + MappingTypeHint mapping_type) = 0; + + // Same as above but for coherent memory, which may be mapped differently. + virtual util::StatusOr MapCoherentMemory( + const Buffer& buffer, DmaDirection direction, + MappingTypeHint mapping_type) { + return MapMemory(buffer, direction, mapping_type); + } + + // Unmaps the given device buffer. + virtual util::Status UnmapMemory(DeviceBuffer buffer) = 0; + + // Same as above but for coherent memory, which may be handled differently. + virtual util::Status UnmapCoherentMemory(DeviceBuffer buffer) { + return UnmapMemory(buffer); + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_H_ diff --git a/driver/memory/address_space_allocator.h b/driver/memory/address_space_allocator.h new file mode 100644 index 0000000..5044bd2 --- /dev/null +++ b/driver/memory/address_space_allocator.h @@ -0,0 +1,43 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
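// A short usage sketch for the AddressSpace interface declared in
// driver/memory/address_space.h above. The helper name and the way the Buffer
// is constructed are assumptions made for illustration; the MapMemory /
// UnmapMemory calls and the DmaDirection / MappingTypeHint arguments come from
// the interface shown above.
util::Status StageInputForDma(AddressSpace* address_space, void* host_ptr,
                              size_t size_bytes) {
  Buffer host_buffer(host_ptr, size_bytes);
  ASSIGN_OR_RETURN(
      DeviceBuffer device_buffer,
      address_space->MapMemory(host_buffer, DmaDirection::kToDevice,
                               MappingTypeHint::kAny));
  // The device-visible address handed to the hardware would be
  // device_buffer.device_address().
  return address_space->UnmapMemory(device_buffer);
}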
+ +#ifndef DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_ALLOCATOR_H_ +#define DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_ALLOCATOR_H_ + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Performs allocations within an address space. +class AddressSpaceAllocator { + public: + virtual ~AddressSpaceAllocator() = default; + + // Allocates |size_bytes| bytes of address space and returns the base address + // of the allocation. + virtual util::StatusOr Allocate(uint64 size_bytes) = 0; + + // Frees the allocation with base address |address| and of size |size_bytes|. + virtual util::Status Free(uint64 address, uint64 size_bytes) = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_ADDRESS_SPACE_ALLOCATOR_H_ diff --git a/driver/memory/address_utilities.h b/driver/memory/address_utilities.h new file mode 100644 index 0000000..a9e979c --- /dev/null +++ b/driver/memory/address_utilities.h @@ -0,0 +1,102 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Few utilities for manipulating host and device addresses. +// Note: In general, host addresses are pointers (void* buffer), where as device +// addresses are uint64. + +#ifndef DARWINN_DRIVER_MEMORY_ADDRESS_UTILITIES_H_ +#define DARWINN_DRIVER_MEMORY_ADDRESS_UTILITIES_H_ + +#include + +#include "driver/hardware_structures.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/math_util.h" +#include "port/status.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Get the offset into a page for a given address. +static inline uint64 GetPageOffset(uint64 address) { + return address & (kHostPageSize - 1); +} + +// Get the offset into a page for a given address. +static inline uint64 GetPageOffset(const void *buffer) { + return GetPageOffset(reinterpret_cast(buffer)); +} + +// Returns true, if page aligned. +static inline bool IsPageAligned(const void *buffer) { + return (GetPageOffset(buffer) == 0); +} + +// Returns true, if page aligned. +static inline bool IsPageAligned(uint64 address) { + return (GetPageOffset(address) == 0); +} + +// Get the number of pages required to back a given address range. +static inline uint64 GetNumberPages(const void *buffer, size_t size_bytes) { + return MathUtil::CeilOfRatio(GetPageOffset(buffer) + size_bytes, + kHostPageSize); +} + +// Get the number of pages required to back a given address range. +static inline uint64 GetNumberPages(uint64 address, size_t size_bytes) { + return MathUtil::CeilOfRatio(GetPageOffset(address) + size_bytes, + kHostPageSize); +} + +// Get the page-aligned address for a given address +static inline uint64 GetPageAddress(uint64 address) { + return address - GetPageOffset(address); +} + +// Get the page-aligned address for a given buffer. 
+static inline const void *GetPageAddressForBuffer(const void *buffer) {
+  return static_cast<const uint8 *>(buffer) - GetPageOffset(buffer);
+}
+
+// Get the page address (in terms of kHostPageSize) for a given page number.
+static constexpr uint64 GetPageAddressFromNumber(uint64 page_num) {
+  return (page_num << kHostPageShiftBits);
+}
+
+// Get the page number (in terms of kHostPageSize) for a given address.
+static constexpr uint64 GetPageNumberFromAddress(uint64 address) {
+  return address >> kHostPageShiftBits;
+}
+
+// Returns whether the given address satisfies the given alignment.
+static inline util::Status IsAligned(const uint8 *buffer,
+                                     uint64 alignment_bytes) {
+  if ((reinterpret_cast<uint64>(buffer) % alignment_bytes) != 0) {
+    return util::FailedPreconditionError(
+        StringPrintf("Buffer is not aligned. address=%p, alignment=%llu.",
+                     buffer, static_cast<unsigned long long>(alignment_bytes)));
+  }
+  return util::Status();  // OK
+}
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_MEMORY_ADDRESS_UTILITIES_H_
diff --git a/driver/memory/buddy_address_space.cc b/driver/memory/buddy_address_space.cc
new file mode 100644
index 0000000..df3ac09
--- /dev/null
+++ b/driver/memory/buddy_address_space.cc
@@ -0,0 +1,92 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "driver/memory/buddy_address_space.h"
+
+#include "api/buffer.h"
+#include "driver/memory/address_utilities.h"
+#include "driver/memory/mmio_address_space.h"
+#include "driver/memory/mmu_mapper.h"
+#include "port/cleanup.h"
+#include "port/errors.h"
+#include "port/logging.h"
+#include "port/ptr_util.h"
+#include "port/status_macros.h"
+#include "port/statusor.h"
+#include "port/std_mutex_lock.h"
+#include "port/tracing.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+
+BuddyAddressSpace::BuddyAddressSpace(uint64 device_virtual_address_start,
+                                     uint64 device_virtual_address_size_bytes,
+                                     MmuMapper* mmu_mapper)
+    : MmioAddressSpace(device_virtual_address_start,
+                       device_virtual_address_size_bytes, mmu_mapper),
+      allocator_(device_virtual_address_start,
+                 device_virtual_address_size_bytes) {}
+
+util::StatusOr<DeviceBuffer> BuddyAddressSpace::MapMemory(
+    const Buffer& buffer, DmaDirection direction,
+    MappingTypeHint mapping_type) {
+  TRACE_SCOPE("BuddyAddressSpace::MapMemory");
+  const void* ptr = buffer.IsPtrType() ? buffer.ptr() : nullptr;
+  if (!ptr && buffer.IsPtrType()) {
+    return util::InvalidArgumentError(
+        "Cannot map an invalid host-memory-backed Buffer.");
+  }
+
+  const size_t size_bytes = buffer.size_bytes();
+  if (size_bytes == 0) {
+    return util::InvalidArgumentError("Cannot map 0 bytes.");
+  }
+
+  auto num_requested_pages = GetNumberPages(ptr, size_bytes);
+  const uint64 allocation_size = num_requested_pages * kHostPageSize;
+
+  StdMutexLock lock(&mutex_);
+  ASSIGN_OR_RETURN(uint64 device_va, allocator_.Allocate(allocation_size));
+
+  // Make sure the allocation is freed upon error.
+ auto cleanup = MakeCleanup( + [this, device_va, allocation_size]() EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + CHECK_OK(allocator_.Free(device_va, allocation_size)); + }); + RETURN_IF_ERROR(Map(buffer, device_va, direction)); + + cleanup.release(); + return DeviceBuffer(device_va + GetPageOffset(ptr), size_bytes); +} + +util::Status BuddyAddressSpace::UnmapMemory(DeviceBuffer buffer) { + TRACE_SCOPE("BuddyAddressSpace::UnmapMemory"); + StdMutexLock lock(&mutex_); + const uint64 device_address = buffer.device_address(); + const size_t size_bytes = buffer.size_bytes(); + + auto num_pages = GetNumberPages(device_address, size_bytes); + const uint64 allocation_size = num_pages * kHostPageSize; + const uint64 device_aligned_va = GetPageAddress(device_address); + + RETURN_IF_ERROR(Unmap(device_aligned_va, num_pages)); + RETURN_IF_ERROR(allocator_.Free(device_aligned_va, allocation_size)); + + return util::Status(); // OK. +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/buddy_address_space.h b/driver/memory/buddy_address_space.h new file mode 100644 index 0000000..a9ab351 --- /dev/null +++ b/driver/memory/buddy_address_space.h @@ -0,0 +1,72 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_BUDDY_ADDRESS_SPACE_H_ +#define DARWINN_DRIVER_MEMORY_BUDDY_ADDRESS_SPACE_H_ + +#include + +#include // NOLINT + +#include "api/buffer.h" +#include "driver/memory/buddy_allocator.h" +#include "driver/memory/dma_direction.h" +#include "driver/memory/mmio_address_space.h" +#include "driver/memory/mmu_mapper.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// A buddy memory allocator for DarwiNN virtual address space segment. +// https://en.wikipedia.org/wiki/Buddy_memory_allocation +class BuddyAddressSpace final : public MmioAddressSpace { + public: + using AddressSpace::MapMemory; // Allows for proper overload resolution. + + BuddyAddressSpace(uint64 device_virtual_address_start, + uint64 device_virtual_address_size_bytes, + MmuMapper* mmu_mapper); + + // This class is neither copyable nor movable. + BuddyAddressSpace(const BuddyAddressSpace&) = delete; + BuddyAddressSpace& operator=(const BuddyAddressSpace&) = delete; + + ~BuddyAddressSpace() override = default; + + // Maps the given host buffer to the device buffer. Returns the mapped device + // buffer on success. + util::StatusOr MapMemory(const Buffer& buffer, + DmaDirection direction, + MappingTypeHint mapping_type) override + LOCKS_EXCLUDED(mutex_); + + // Unmaps the given device buffer. + util::Status UnmapMemory(DeviceBuffer buffer) override LOCKS_EXCLUDED(mutex_); + + private: + mutable std::mutex mutex_; + + // Allocator that manages address space resources. 
+  BuddyAllocator allocator_ GUARDED_BY(mutex_);
+};
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_MEMORY_BUDDY_ADDRESS_SPACE_H_
diff --git a/driver/memory/buddy_allocator.cc b/driver/memory/buddy_allocator.cc
new file mode 100644
index 0000000..f04f29e
--- /dev/null
+++ b/driver/memory/buddy_allocator.cc
@@ -0,0 +1,179 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "driver/memory/buddy_allocator.h"
+
+#include "absl/strings/str_format.h"
+#include "driver/memory/address_utilities.h"
+#include "port/std_mutex_lock.h"
+#include "port/stringprintf.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+namespace {
+
+// Number of bits in the address space.
+constexpr int kAddressSpaceBits = 64;
+
+// Number of bins accounts for powers of 2 in a 64-bit address space, but does
+// not need to include bins for sizes smaller than the page size.
+constexpr int kNumBins = kAddressSpaceBits - kHostPageShiftBits;
+
+// Based on:
+// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2Float
+uint64 RoundUpToNextPowerOfTwo(uint64 x) {
+  x--;
+  x |= x >> 1;   // handle 2 bit numbers
+  x |= x >> 2;   // handle 4 bit numbers
+  x |= x >> 4;   // handle 8 bit numbers
+  x |= x >> 8;   // handle 16 bit numbers
+  x |= x >> 16;  // handle 32 bit numbers
+  x |= x >> 32;  // handle 64 bit numbers
+  x++;
+
+  return x;
+}
+
+// Returns the bin index given an order. The unit of allocation is a host page,
+// so the smallest bin (bin 0) is for anything that is <= host page size.
+int GetBinFromOrder(int order) {
+  CHECK_GE(order, kHostPageShiftBits);
+  return order - kHostPageShiftBits;
+}
+
+// Returns the order for a given bin. For example, bin 2 is four times the host
+// page size; on x86 that is 2^(12+2) bytes.
+int GetOrderFromBin(int bin) { return bin + kHostPageShiftBits; }
+
+// For a given allocation request size, returns the index of the bin (i.e. the
+// index used into free_blocks_ / allocated_blocks_) that the size belongs to.
+// This is based on:
+// https://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightModLookup
+// Rationale:
+// pow(2, i) for 0 <= i < 32 each have a distinct residue modulo 37. We use
+// that property to perform a fast lookup.
+int FindBin(uint64 num_pages) {
+  const uint64 nearest_power_of_two = RoundUpToNextPowerOfTwo(num_pages);
+  // The trick below only works up to 2^31.
+  CHECK_LE(nearest_power_of_two, 1ULL << 31);
+  static constexpr int
+      kMod37BitPosition[] =  // map a bit value mod 37 to its position
+      {32, 0, 1, 26, 2, 23, 27, 0, 3, 16, 24, 30, 28, 11, 0, 13, 4, 7, 17,
+       0, 25, 22, 31, 15, 29, 10, 12, 6, 0, 21, 14, 9, 5, 20, 8, 19, 18};
+  const int num_zero = kMod37BitPosition[nearest_power_of_two % 37];
+  return std::max(GetBinFromOrder(num_zero), 0);
+}
+
+// Returns the number of pages required to store the specified size in bytes.
+uint64 GetNumPages(uint64 size_bytes) { + const int num_pages = size_bytes >> kHostPageShiftBits; + const int spillover_page = + (size_bytes & ((1 << kHostPageShiftBits) - 1)) ? 1 : 0; + return num_pages + spillover_page; +} + +} // namespace + +BuddyAllocator::BuddyAllocator(uint64 address_space_start, + uint64 address_space_size_bytes) + : address_space_start_(address_space_start), + free_blocks_(kNumBins), + allocated_blocks_(kNumBins) { + uint64 offset = 0; + // Initialize all bins. In the worst case we'd miss up to kHostPageSize - 1 + // bytes. + for (int i = kAddressSpaceBits - 1; i >= kHostPageShiftBits; --i) { + const uint64 mask = (1ULL << i); + if (address_space_size_bytes & mask) { + free_blocks_[GetBinFromOrder(i)].insert(offset); + offset += mask; + } + } +} + +util::StatusOr BuddyAllocator::Allocate(uint64 size_bytes) { + StdMutexLock lock(&mutex_); + if (size_bytes == 0) { + return util::InvalidArgumentError("Cannot allocate 0 bytes."); + } + + const uint64 num_requested_pages = GetNumPages(size_bytes); + const int desirable_bin = FindBin(num_requested_pages * kHostPageSize); + int nearest_bin = desirable_bin; + + // Find the nearest bin that has at least something left. + while (nearest_bin < free_blocks_.size() && + free_blocks_[nearest_bin].empty()) { + ++nearest_bin; + } + if (nearest_bin >= free_blocks_.size()) { + return util::ResourceExhaustedError( + absl::StrFormat("Can't allocate for 0x%llx bytes.", size_bytes)); + } + + const auto& block = free_blocks_[nearest_bin].begin(); + const uint64 offset = *block; + + free_blocks_[nearest_bin].erase(block); + allocated_blocks_[desirable_bin].insert(offset); + + // If nearest bin != desirable bin, insert blocks produced by splitting higher + // order ones + for (int i = nearest_bin - 1; i >= desirable_bin; --i) { + const uint64 split_offset = offset + (1ULL << GetOrderFromBin(i)); + free_blocks_[i].insert(split_offset); + } + + const uint64 allocated_address = offset + address_space_start_; + return allocated_address; +} + +util::Status BuddyAllocator::Free(uint64 address, uint64 size_bytes) { + StdMutexLock lock(&mutex_); + const uint64 num_pages = GetNumPages(size_bytes); + const int bin = FindBin(num_pages * kHostPageSize); + + const uint64 offset = address - address_space_start_; + auto allocated_iterator = allocated_blocks_[bin].find(offset); + if (allocated_iterator == allocated_blocks_[bin].end()) { + return util::InvalidArgumentError(absl::StrFormat( + "Allocated block with address 0x%llx and size 0x%llx not found.", + address, size_bytes)); + } + allocated_blocks_[bin].erase(allocated_iterator); + + uint64 coalesced_offset = offset; + for (int buddy_bin = bin; buddy_bin < free_blocks_.size(); ++buddy_bin) { + // Find nearby block ("buddy") if any. + const uint64 buddy_offset = + coalesced_offset ^ (1ULL << GetOrderFromBin(buddy_bin)); + + auto buddy_iterator = free_blocks_[buddy_bin].find(buddy_offset); + if (buddy_iterator != free_blocks_[buddy_bin].end()) { + // Merging with the buddy at buddy_offset. + free_blocks_[buddy_bin].erase(buddy_iterator); + coalesced_offset &= buddy_offset; + } else { + // We are done - can't coalesce more. Insert the block to the current bin. + free_blocks_[buddy_bin].insert(coalesced_offset); + break; + } + } + + return util::Status(); // OK. 
+} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/buddy_allocator.h b/driver/memory/buddy_allocator.h new file mode 100644 index 0000000..a480ddf --- /dev/null +++ b/driver/memory/buddy_allocator.h @@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_BUDDY_ALLOCATOR_H_ +#define DARWINN_DRIVER_MEMORY_BUDDY_ALLOCATOR_H_ + +#include +#include // NOLINT(build/c++11) +#include +#include + +#include "driver/memory/address_space_allocator.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + + +namespace platforms { +namespace darwinn { +namespace driver { + +// A buddy address space allocator. +// https://en.wikipedia.org/wiki/Buddy_memory_allocation +// +// Note that allocations in this buddy allocator are made on 4KB aligned +// boundaries and are 4KB granular in size, even if the requested size is not +// 4KB granular. +// +// Class is thread unsafe. +class BuddyAllocator : public AddressSpaceAllocator { + public: + // Constructs an allocator that will allocate from a contiguous address range + // starting with |address_space_start| and of size |address_space_size_bytes|. + // Allocations are always aligned on 4KB boundaries and are increments of 4KB + // in size. + BuddyAllocator(uint64 address_space_start, uint64 address_space_size_bytes); + + ~BuddyAllocator() override = default; + + ////////////////////////////////////////////////////////////////////////////// + // Implementation of Allocator interface + // + util::StatusOr Allocate(uint64 size_bytes) override + LOCKS_EXCLUDED(mutex_); + util::Status Free(uint64 address, uint64 size_bytes) override + LOCKS_EXCLUDED(mutex_); + + private: + // Starting address of the space being allocated from. + const uint64 address_space_start_; + + // Sets of free blocks, indexed by order. + std::vector> free_blocks_ GUARDED_BY(mutex_); + + // Sets of allocated blocks, indexed by order. + std::vector> allocated_blocks_ GUARDED_BY(mutex_); + + mutable std::mutex mutex_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_BUDDY_ALLOCATOR_H_ diff --git a/driver/memory/dma_direction.h b/driver/memory/dma_direction.h new file mode 100644 index 0000000..9512694 --- /dev/null +++ b/driver/memory/dma_direction.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_MEMORY_DMA_DIRECTION_H_
+#define DARWINN_DRIVER_MEMORY_DMA_DIRECTION_H_
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+
+// Mimics the Linux kernel dma_data_direction enum in the DMA API.
+// This indicates the direction in which data moves during a DMA transfer, and
+// is a useful hint to pass to the kernel when mapping buffers.
+enum class DmaDirection {
+  // DMA_TO_DEVICE: CPU caches are flushed at mapping time.
+  kToDevice = 1,
+  // DMA_FROM_DEVICE: CPU caches are invalidated at unmapping time.
+  kFromDevice = 2,
+  // DMA_BIDIRECTIONAL: Both of the above.
+  kBidirectional = 0,
+};
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_MEMORY_DMA_DIRECTION_H_
diff --git a/driver/memory/dram_allocator.h b/driver/memory/dram_allocator.h
new file mode 100644
index 0000000..e8f61fe
--- /dev/null
+++ b/driver/memory/dram_allocator.h
@@ -0,0 +1,52 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_MEMORY_DRAM_ALLOCATOR_H_
+#define DARWINN_DRIVER_MEMORY_DRAM_ALLOCATOR_H_
+
+#include "api/dram_buffer.h"
+#include "port/statusor.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+
+// An abstract class for a DRAM allocator. Each chip will have a concrete
+// implementation.
+class DramAllocator {
+ public:
+  DramAllocator() = default;
+  virtual ~DramAllocator() = default;
+
+  // This class is neither copyable nor movable.
+  DramAllocator(const DramAllocator&) = delete;
+  DramAllocator& operator=(const DramAllocator&) = delete;
+
+  // Open and close the allocator. Buffer allocation can happen even when the
+  // allocator is closed, but those buffers should not be used while the
+  // allocator is closed.
+  virtual util::Status Open() = 0;
+  virtual util::Status Close() = 0;
+
+  // Allocates and returns a DRAM buffer of the requested size. It returns an
+  // error if there is not enough space.
+  virtual util::StatusOr<std::shared_ptr<DramBuffer>> AllocateBuffer(
+      size_t size_bytes) = 0;
+};
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_DRIVER_MEMORY_DRAM_ALLOCATOR_H_
diff --git a/driver/memory/dual_address_space.cc b/driver/memory/dual_address_space.cc
new file mode 100644
index 0000000..01a8791
--- /dev/null
+++ b/driver/memory/dual_address_space.cc
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/memory/dual_address_space.h" + +#include "driver/hardware_structures.h" +#include "driver/memory/buddy_address_space.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +DualAddressSpace::DualAddressSpace( + const config::ChipStructures& chip_structures, MmuMapper* mmu_mapper) { + const int num_simple_entries = + GetNumSimplePageTableEntries(chip_structures.num_page_table_entries); + + simple_ = gtl::MakeUnique( + 0, kHostPageSize * num_simple_entries, mmu_mapper); + + extended_ = gtl::MakeUnique( + kExtendedAddressSpaceStart, + kExtendedPageTableEntryAddressableBytes * + GetNumExtendedPageTableEntries( + chip_structures.num_page_table_entries), + mmu_mapper); +} + +util::StatusOr DualAddressSpace::MapMemory( + const Buffer& buffer, DmaDirection direction, + MappingTypeHint mapping_type) { + switch (mapping_type) { + case MappingTypeHint::kSimple: + return simple_->MapMemory(buffer, direction, mapping_type); + + case MappingTypeHint::kExtended: + case MappingTypeHint::kAny: + return extended_->MapMemory(buffer, direction, mapping_type); + } +} + +util::Status DualAddressSpace::UnmapMemory(DeviceBuffer buffer) { + return DetermineSource(buffer)->UnmapMemory(buffer); +} + +AddressSpace* DualAddressSpace::DetermineSource( + const DeviceBuffer& device_buffer) const { + if (device_buffer.device_address() & kExtendedVirtualAddressBit) { + return extended_.get(); + } else { + return simple_.get(); + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/dual_address_space.h b/driver/memory/dual_address_space.h new file mode 100644 index 0000000..01683fa --- /dev/null +++ b/driver/memory/dual_address_space.h @@ -0,0 +1,65 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_DUAL_ADDRESS_SPACE_H_ +#define DARWINN_DRIVER_MEMORY_DUAL_ADDRESS_SPACE_H_ + +#include "driver/config/chip_structures.h" +#include "driver/memory/address_space.h" +#include "driver/memory/mmu_mapper.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// An address space implementation that works with a split simple/extended page +// table. +class DualAddressSpace final : public AddressSpace { + public: + using AddressSpace::MapMemory; // Allows for proper overload resolution. + + DualAddressSpace(const config::ChipStructures& chip_structures, + MmuMapper* mmu_mapper); + + // This class is neither copyable nor movable. + DualAddressSpace(const DualAddressSpace&) = delete; + DualAddressSpace& operator=(const DualAddressSpace&) = delete; + + virtual ~DualAddressSpace() = default; + + // Maps the buffer to the device buffer. Returns the mapped device + // buffer on success. 
+ util::StatusOr MapMemory(const Buffer& buffer, + DmaDirection direction, + MappingTypeHint mapping_type) override; + + // Unmaps the given device buffer. + util::Status UnmapMemory(DeviceBuffer buffer) override; + + private: + // Determines which address space the device buffer was allocated from. + AddressSpace* DetermineSource(const DeviceBuffer& device_buffer) const; + + // Underlying simple address space. + std::unique_ptr simple_; + + // Underlying extended address space. + std::unique_ptr extended_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_DUAL_ADDRESS_SPACE_H_ diff --git a/driver/memory/fake_dram_allocator.cc b/driver/memory/fake_dram_allocator.cc new file mode 100644 index 0000000..2bd43e9 --- /dev/null +++ b/driver/memory/fake_dram_allocator.cc @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/memory/fake_dram_allocator.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +FakeDramBuffer::FakeDramBuffer(size_t size_bytes) : size_bytes_(size_bytes) { + ptr_ = malloc(size_bytes); +} + +FakeDramBuffer::~FakeDramBuffer() { free(ptr_); } + +util::Status FakeDramBuffer::ReadFrom(void* source) { + memcpy(ptr_, source, size_bytes_); + return util::OkStatus(); +} + +util::Status FakeDramBuffer::WriteTo(void* destination) { + memcpy(destination, ptr_, size_bytes_); + return util::OkStatus(); +} + +util::StatusOr> FakeDramAllocator::AllocateBuffer( + size_t size_bytes) { + return {std::make_shared(size_bytes)}; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/fake_dram_allocator.h b/driver/memory/fake_dram_allocator.h new file mode 100644 index 0000000..980790d --- /dev/null +++ b/driver/memory/fake_dram_allocator.h @@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_FAKE_DRAM_ALLOCATOR_H_ +#define DARWINN_DRIVER_MEMORY_FAKE_DRAM_ALLOCATOR_H_ + +#include "api/dram_buffer.h" +#include "driver/memory/dram_allocator.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Pretends to be an on-chip DRAM buffer while it actually is a host DRAM +// buffer. This is useful for reference driver and such. 
+class FakeDramBuffer : public DramBuffer { + public: + FakeDramBuffer(size_t size_bytes); + ~FakeDramBuffer() override; + + int fd() const override { return 1; } + size_t size_bytes() const override { return size_bytes_; } + + util::Status ReadFrom(void* source) override; + util::Status WriteTo(void* destination) override; + + private: + // Size of the buffer. + size_t size_bytes_; + + // Pointer to start of the buffer. + void* ptr_; +}; + +// A DRAM allocator that creates fake DRAM buffers. This is useful for reference +// driver and such. +class FakeDramAllocator : public DramAllocator { + public: + FakeDramAllocator() = default; + ~FakeDramAllocator() override = default; + + // This class is neither copyable nor movable. + FakeDramAllocator(const FakeDramAllocator&) = delete; + FakeDramAllocator& operator=(const FakeDramAllocator&) = delete; + + util::Status Open() override { return util::OkStatus(); } + util::Status Close() override { return util::OkStatus(); } + + util::StatusOr> AllocateBuffer( + size_t size_bytes) override; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_FAKE_DRAM_ALLOCATOR_H_ diff --git a/driver/memory/fake_mmu_mapper.cc b/driver/memory/fake_mmu_mapper.cc new file mode 100644 index 0000000..fdddbb0 --- /dev/null +++ b/driver/memory/fake_mmu_mapper.cc @@ -0,0 +1,96 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
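// A small round-trip sketch for the FakeDramAllocator declared in
// driver/memory/fake_dram_allocator.h above: its buffers are plain host
// allocations, so a ReadFrom() followed by a WriteTo() is just two memcpys.
// The helper name, buffer size, and fill values are illustrative assumptions.
util::Status FakeDramRoundTrip() {
  FakeDramAllocator allocator;
  RETURN_IF_ERROR(allocator.Open());
  ASSIGN_OR_RETURN(std::shared_ptr<DramBuffer> dram,
                   allocator.AllocateBuffer(/*size_bytes=*/64));
  std::vector<unsigned char> source(64, 0xAB);
  std::vector<unsigned char> destination(64, 0);
  RETURN_IF_ERROR(dram->ReadFrom(source.data()));      // host -> fake "DRAM"
  RETURN_IF_ERROR(dram->WriteTo(destination.data()));  // fake "DRAM" -> host
  return allocator.Close();
}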
+ +#include "driver/memory/fake_mmu_mapper.h" + +#include "driver/hardware_structures.h" +#include "driver/memory/address_utilities.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status FakeMmuMapper::DoMap(const void *buffer, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) { + StdMutexLock lock(&mutex_); + CHECK(IsPageAligned(buffer)); + CHECK(IsPageAligned(device_virtual_address)); + + const uint8 *host_addr_start = static_cast(buffer); + for (int i = 0; i < num_pages; ++i) { + const uint8 *host_addr = i * kHostPageSize + host_addr_start; + uint64 device_addr = i * kHostPageSize + device_virtual_address; + + CHECK(device_to_host_.find(device_addr) == device_to_host_.end()); + device_to_host_.emplace(device_addr, host_addr); + } + + return util::Status(); // OK +} + +util::Status FakeMmuMapper::DoMap(int fd, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) { + return DoMap(reinterpret_cast(fd * kHostPageSize), num_pages, + device_virtual_address, direction); +} + +util::Status FakeMmuMapper::DoUnmap(const void *buffer, int num_pages, + uint64 device_virtual_address) { + StdMutexLock lock(&mutex_); + CHECK(IsPageAligned(buffer)); + CHECK(IsPageAligned(device_virtual_address)); + + for (int i = 0; i < num_pages; ++i) { + uint64 device_addr = i * kHostPageSize + device_virtual_address; + + // TODO: Validate that the device virtual address and buffer + // corresponds to the buffer that was originally mapped. + CHECK(device_to_host_.find(device_addr) != device_to_host_.end()); + device_to_host_.erase(device_addr); + } + + return util::Status(); // OK +} + +util::Status FakeMmuMapper::DoUnmap(int fd, int num_pages, + uint64 device_virtual_address) { + return DoUnmap(reinterpret_cast(fd * kHostPageSize), num_pages, + device_virtual_address); +} + +util::StatusOr FakeMmuMapper::TranslateDeviceAddress( + uint64 device_virtual_address) const { + uint64 aligned_device_addr = GetPageAddress(device_virtual_address); + + StdMutexLock lock(&mutex_); + auto iter = device_to_host_.find(aligned_device_addr); + if (iter == device_to_host_.end()) { + return util::NotFoundError("Device address not mapped."); + } + + uint8 *aligned_host_addr = const_cast(iter->second); + auto host_address = aligned_host_addr + GetPageOffset(device_virtual_address); + + CHECK(host_address != nullptr); // StatusOr doesn't like nullptr! + return host_address; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/fake_mmu_mapper.h b/driver/memory/fake_mmu_mapper.h new file mode 100644 index 0000000..fa4e88f --- /dev/null +++ b/driver/memory/fake_mmu_mapper.h @@ -0,0 +1,73 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
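// Sketch of the translation path implemented in fake_mmu_mapper.cc above: once
// Map() has recorded a page-granular device-to-host entry,
// TranslateDeviceAddress() returns the original host pointer plus the in-page
// offset. The helper name, device address, and offset are illustrative;
// |host_page| is assumed to be page-aligned host memory of at least one page.
util::Status CheckFakeTranslation(FakeMmuMapper* mmu, void* host_page) {
  const uint64 device_va = 0x2000;
  RETURN_IF_ERROR(mmu->Map(Buffer(host_page, kHostPageSize), device_va));
  ASSIGN_OR_RETURN(void* translated,
                   mmu->TranslateDeviceAddress(device_va + 16));
  CHECK_EQ(translated, static_cast<void*>(static_cast<char*>(host_page) + 16));
  return mmu->Unmap(Buffer(host_page, kHostPageSize), device_va);
}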
+ +#ifndef DARWINN_DRIVER_MEMORY_FAKE_MMU_MAPPER_H_ +#define DARWINN_DRIVER_MEMORY_FAKE_MMU_MAPPER_H_ + +#include + +#include "driver/memory/dma_direction.h" +#include "driver/memory/mmu_mapper.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// A fake MMU mapper implementation that does not accurately model +// the underlying hardware, but behaves the same way. +class FakeMmuMapper : public MmuMapper { + public: + FakeMmuMapper() {} + ~FakeMmuMapper() override {} + + // This class is neither copyable nor movable. + FakeMmuMapper(const FakeMmuMapper&) = delete; + FakeMmuMapper& operator=(const FakeMmuMapper&) = delete; + + // Overrides from MmuMapper + util::Status Open(int num_simple_page_table_entries_requested) override { + return util::Status(); // OK + } + util::Status Close() override { return util::Status(); } + util::StatusOr TranslateDeviceAddress( + uint64 device_virtual_address) const override; + + protected: + util::Status DoMap(const void* buffer, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) override; + util::Status DoUnmap(const void* buffer, int num_pages, + uint64 device_virtual_address) override; + + // Fake mapping: assuming physical address = fd * kHostPageSize. + util::Status DoMap(int fd, int num_pages, uint64 device_virtual_address, + DmaDirection direction) override; + util::Status DoUnmap(int fd, int num_pages, + uint64 device_virtual_address) override; + + // "Page table" to track device addr to host mappings. + std::map device_to_host_ GUARDED_BY(mutex_); + + // Guards device_to_host_. + mutable std::mutex mutex_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_FAKE_MMU_MAPPER_H_ diff --git a/driver/memory/mmio_address_space.cc b/driver/memory/mmio_address_space.cc new file mode 100644 index 0000000..f338c92 --- /dev/null +++ b/driver/memory/mmio_address_space.cc @@ -0,0 +1,95 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/memory/mmio_address_space.h" + +#include "api/buffer.h" +#include "driver/memory/address_utilities.h" +#include "port/errors.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status MmioAddressSpace::Map(const Buffer& buffer, uint64 device_address, + DmaDirection direction) { + TRACE_SCOPE("MmioAddressSpace::Map"); + CHECK(IsPageAligned(device_address)); + + StdMutexLock lock(&mutex_); + + // If already mapped, fail. + // TODO: Add a finer grained check, e.g., overlap, if necessary? 
+  if (mapped_.find(device_address) != mapped_.end()) {
+    return util::InvalidArgumentError(
+        "Trying to map a segment that is already mapped.");
+  }
+
+  RETURN_IF_ERROR(mmu_mapper_->Map(buffer, device_address, direction));
+
+  // Track mapped segments.
+  // Make a copy of Buffer since the given buffer may change later.
+  auto insert_result = mapped_.insert(
+      {device_address, buffer});
+  CHECK(insert_result.second);
+
+  // buffer.ptr() may or may not be valid.
+  // TODO: print out buffer address if the buffer has valid ptr().
+  VLOG(4) << StringPrintf(
+      "MapMemory() page-aligned : device_address = 0x%016llx",
+      static_cast<unsigned long long>(device_address));  // NOLINT(runtime/int)
+
+  return util::Status();  // OK
+}
+
+util::Status MmioAddressSpace::Unmap(uint64 device_address,
+                                     int num_released_pages) {
+  TRACE_SCOPE("MmioAddressSpace::Unmap");
+  // TODO: verify num_released_pages if the Buffer is backed by host
+  // memory.
+  CHECK(IsPageAligned(device_address));
+
+  StdMutexLock lock(&mutex_);
+
+  auto find_result = mapped_.find(device_address);
+  if (find_result == mapped_.end()) {
+    return util::InvalidArgumentError(
+        "Trying to unmap a segment that is not already mapped.");
+  }
+
+  // Need to pass the Buffer object as the MMU mapper might require the backing
+  // file descriptor underneath.
+  RETURN_IF_ERROR(mmu_mapper_->Unmap(find_result->second, device_address));
+
+  // buffer.ptr() may or may not be valid.
+  // TODO: print out buffer address if the buffer has valid ptr().
+  VLOG(4) << StringPrintf(
+      "UnmapMemory() page-aligned : device_address = 0x%016llx, num_pages = %d",
+      static_cast<unsigned long long>(device_address),  // NOLINT(runtime/int)
+      num_released_pages);
+
+  mapped_.erase(find_result);
+
+  return util::Status();  // OK
+}
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
diff --git a/driver/memory/mmio_address_space.h b/driver/memory/mmio_address_space.h
new file mode 100644
index 0000000..45a5b08
--- /dev/null
+++ b/driver/memory/mmio_address_space.h
@@ -0,0 +1,104 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_MEMORY_MMIO_ADDRESS_SPACE_H_
+#define DARWINN_DRIVER_MEMORY_MMIO_ADDRESS_SPACE_H_
+
+#include
+#include
+#include  // NOLINT
+
+#include "api/buffer.h"
+#include "driver/memory/address_space.h"
+#include "driver/memory/address_utilities.h"
+#include "driver/memory/dma_direction.h"
+#include "driver/memory/mmu_mapper.h"
+#include "port/integral_types.h"
+#include "port/logging.h"
+#include "port/statusor.h"
+#include "port/thread_annotations.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+
+// A class to manage a DarwiNN virtual address space segment when MMIO is used.
+class MmioAddressSpace : public AddressSpace { + public: + MmioAddressSpace(uint64 device_virtual_address_start, + uint64 device_virtual_address_size_bytes, + MmuMapper* mmu_mapper) + : AddressSpace(), + device_virtual_address_start_(device_virtual_address_start), + device_virtual_address_size_bytes_(device_virtual_address_size_bytes), + mmu_mapper_(mmu_mapper) { + CHECK(mmu_mapper != nullptr); + CHECK(IsPageAligned(device_virtual_address_start)); + CHECK(IsPageAligned(device_virtual_address_size_bytes)); + } + + // This class is neither copyable nor movable. + MmioAddressSpace(const MmioAddressSpace&) = delete; + MmioAddressSpace& operator=(const MmioAddressSpace&) = delete; + + virtual ~MmioAddressSpace() = default; + + protected: + // Maps the entire given Buffer, and stores the mapping information. Returns + // an error if trying to map already mapped Buffer. + util::Status Map(const Buffer& buffer, uint64 device_address, + DmaDirection direction) LOCKS_EXCLUDED(mutex_); + + // Checks and unmaps device virtual address. |device_address| must be + // page-aligned. + util::Status Unmap(uint64 device_address, int num_released_pages) + LOCKS_EXCLUDED(mutex_); + + // Member accessors for inherited classes. + uint64 device_virtual_address_start() const { + return device_virtual_address_start_; + } + uint64 device_virtual_address_size_bytes() const { + return device_virtual_address_size_bytes_; + } + + // Returns last device virtual address. + uint64 GetLastDeviceVirtualAddress() const { + return device_virtual_address_start_ + device_virtual_address_size_bytes_; + } + + private: + // Device address space start. + const uint64 device_virtual_address_start_; + + // Device address space size in bytes. + const uint64 device_virtual_address_size_bytes_; + + // Underlying MMU mapper. + MmuMapper* const mmu_mapper_ GUARDED_BY(mutex_); + + // Guards |mmu_mapper_| and |mapped_|. + mutable std::mutex mutex_; + + // Tracks already mapped segments. + // key - aligned device virtual address. + // value - {host address, number of mapped pages} + std::map mapped_ GUARDED_BY(mutex_); +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_MMIO_ADDRESS_SPACE_H_ diff --git a/driver/memory/mmu_mapper.cc b/driver/memory/mmu_mapper.cc new file mode 100644 index 0000000..ba5bd8d --- /dev/null +++ b/driver/memory/mmu_mapper.cc @@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/memory/mmu_mapper.h" + +#include "driver/memory/address_utilities.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status MmuMapper::Map(const Buffer &buffer, uint64 device_virtual_address, + DmaDirection direction) { + TRACE_SCOPE("MmuMapper::Map"); + // Buffers backed by file descriptors do not have valid ptr(). + const void *ptr = buffer.FileDescriptorBacked() ? 
nullptr : buffer.ptr(); + if (buffer.IsPtrType() && ptr == nullptr) { + return util::InvalidArgumentError("Cannot map a Buffer of nullptr."); + } + + const size_t size_bytes = buffer.size_bytes(); + if (size_bytes == 0) { + return util::InvalidArgumentError("Cannot map 0 bytes."); + } + + auto num_requested_pages = GetNumberPages(ptr, size_bytes); + + // Buffers backed by file descriptors are handled differently. + if (buffer.FileDescriptorBacked()) { + return DoMap(buffer.fd(), num_requested_pages, device_virtual_address, + direction); + } else { + auto aligned_buffer_addr = GetPageAddressForBuffer(ptr); + return DoMap(aligned_buffer_addr, num_requested_pages, + device_virtual_address, direction); + } +} + +util::Status MmuMapper::Unmap(const Buffer &buffer, + uint64 device_virtual_address) { + TRACE_SCOPE("MmuMapper::Unmap"); + // Buffers backed by file descriptors do not have valid ptr(). + const void *ptr = buffer.FileDescriptorBacked() ? nullptr : buffer.ptr(); + if (buffer.IsPtrType() && ptr == nullptr) { + return util::InvalidArgumentError("Cannot unmap a Buffer of nullptr."); + } + + const size_t size_bytes = buffer.size_bytes(); + if (size_bytes == 0) { + return util::InvalidArgumentError("Cannot unmap 0 bytes."); + } + + auto num_mapped_pages = GetNumberPages(ptr, size_bytes); + + // Buffers backed by file descriptors are handled differently. + if (buffer.FileDescriptorBacked()) { + return DoUnmap(buffer.fd(), num_mapped_pages, device_virtual_address); + } else { + auto aligned_buffer_addr = GetPageAddressForBuffer(ptr); + return DoUnmap(aligned_buffer_addr, num_mapped_pages, + device_virtual_address); + } +} + + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/mmu_mapper.h b/driver/memory/mmu_mapper.h new file mode 100644 index 0000000..aed79cc --- /dev/null +++ b/driver/memory/mmu_mapper.h @@ -0,0 +1,111 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_MMU_MAPPER_H_ +#define DARWINN_DRIVER_MEMORY_MMU_MAPPER_H_ + +#include "api/buffer.h" +#include "driver/device_buffer.h" +#include "driver/memory/dma_direction.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Abstract class for mapping memory on device MMU. +class MmuMapper { + public: + virtual ~MmuMapper() = default; + + // Opens / Closes the MMU interface. + // - Reserve |num_simple_page_table_entries_requested| page table entries for + // simple indexing. Remaining entries will be used for extended addressing. + virtual util::Status Open(int num_simple_page_table_entries_requested) = 0; + virtual util::Status Close() = 0; + + // Maps |num_pages| from the memory backing |buffer| to + // |device_virtual_address|. 
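+ // This overload simply defaults |direction| to DmaDirection::kBidirectional;
+ // callers that know the actual transfer direction should prefer the overload
+ // below so the hint can be forwarded to DoMap().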
+ util::Status Map(const Buffer &buffer, uint64 device_virtual_address) { + return Map(buffer, device_virtual_address, DmaDirection::kBidirectional); + } + + // Same as above but with a hint indicating the buffer transfer direction. + util::Status Map(const Buffer &buffer, uint64 device_virtual_address, + DmaDirection direction); + + // Unmaps previously mapped Buffer. + util::Status Unmap(const Buffer &buffer, uint64 device_virtual_address); + + // Translates device address to host virtual address. This function is + // typically not implemented and will return an UNIMPLEMENTED Status. It is + // only useful when MMU needs to be modeled directly (as is the case when + // using IpCore without the HIB, or with no MMU). + // + // Note that the device address here is the address that is output by the + // hardware, which may be physical or virtual, depending if an MMU is present + // or not. + virtual util::StatusOr TranslateDeviceAddress( + uint64 device_address) const { + return util::UnimplementedError("Translate not supported."); + } + + // Determines if a virtual address (obtained from TranslateDeviceAddress) + // points into the extended page tables of this MMU. If so, reads + // "size_in_bytes" bytes of data from "address" to "buffer" and returns true. + // Generally this is false, except in certain simulations where the MMU is + // modeled directly. + virtual bool TryReadExtendedPageTable(const void *address, void *buffer, + int size_in_bytes) const { + return false; + } + + protected: + // Maps |num_pages| from |buffer| (the host virtual address) to + // |device_virtual_address|. All addresses must be page aligned. + // Called by public version of Map when the buffer is backed by host memory. + virtual util::Status DoMap(const void *buffer, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) = 0; + + // Maps file descriptor to |device_virtual_address|. Default = unimplemented. + virtual util::Status DoMap(int fd, int num_pages, + uint64 device_virtual_address, + DmaDirection direction) { + return util::UnimplementedError( + "File descriptor-backed mapping not supported."); + } + + // Unmaps previously mapped addresses. + // Called by public version of Unmap when the buffer is backed by host memory. + virtual util::Status DoUnmap(const void *buffer, int num_pages, + uint64 device_virtual_address) = 0; + + // Unmaps previously mapped file descriptor based buffer. Default = + // unimplemented. + + virtual util::Status DoUnmap(int fd, int num_pages, + uint64 device_virtual_address) { + return util::UnimplementedError( + "File descriptor-backed unmapping not supported."); + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_MMU_MAPPER_H_ diff --git a/driver/memory/nop_address_space.cc b/driver/memory/nop_address_space.cc new file mode 100644 index 0000000..1bf1d4f --- /dev/null +++ b/driver/memory/nop_address_space.cc @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/memory/nop_address_space.h" + +#include "api/buffer.h" +#include "port/errors.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::StatusOr<DeviceBuffer> NopAddressSpace::MapMemory( + const Buffer& buffer, DmaDirection direction, + MappingTypeHint mapping_type) { + if (!buffer.IsValid()) { + return util::InvalidArgumentError("Invalid buffer."); + } + + return DeviceBuffer(reinterpret_cast<uint64>(buffer.ptr()), + buffer.size_bytes()); +} + +util::StatusOr<Buffer> NopAddressSpace::Translate( + const DeviceBuffer& buffer) const { + return Buffer(reinterpret_cast<void*>(buffer.device_address()), + buffer.size_bytes()); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/memory/nop_address_space.h b/driver/memory/nop_address_space.h new file mode 100644 index 0000000..6c8d01d --- /dev/null +++ b/driver/memory/nop_address_space.h @@ -0,0 +1,66 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_NOP_ADDRESS_SPACE_H_ +#define DARWINN_DRIVER_MEMORY_NOP_ADDRESS_SPACE_H_ + +#include +#include + +#include "api/buffer.h" +#include "driver/device_buffer.h" +#include "driver/memory/address_space.h" +#include "driver/memory/dma_direction.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// No-op address space implementation. MapMemory and UnmapMemory are no-ops; +// the host address equals the device virtual address. +class NopAddressSpace : public AddressSpace { + public: + using AddressSpace::MapMemory; // Allows for proper overload resolution. + + NopAddressSpace() = default; + + // This class is neither copyable nor movable. + NopAddressSpace(const NopAddressSpace&) = delete; + NopAddressSpace& operator=(const NopAddressSpace&) = delete; + + virtual ~NopAddressSpace() = default; + + // Maps the given host buffer to the device buffer. Returns the mapped device + // buffer on success. + util::StatusOr<DeviceBuffer> MapMemory(const Buffer& buffer, + DmaDirection direction, + MappingTypeHint mapping_type) override; + + // Unmaps the given device address range. + util::Status UnmapMemory(DeviceBuffer buffer) override { + return util::Status(); // OK + } + + // Translates device buffer to host buffer. + util::StatusOr<Buffer> Translate( + const DeviceBuffer& buffer) const; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_NOP_ADDRESS_SPACE_H_ diff --git a/driver/memory/null_dram_allocator.h b/driver/memory/null_dram_allocator.h new file mode 100644 index 0000000..00554a3 --- /dev/null +++ b/driver/memory/null_dram_allocator.h @@ -0,0 +1,56 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MEMORY_NULL_DRAM_ALLOCATOR_H_ +#define DARWINN_DRIVER_MEMORY_NULL_DRAM_ALLOCATOR_H_ + +#include "api/dram_buffer.h" +#include "driver/memory/dram_allocator.h" +#include "port/errors.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// A DRAM allocator to be used for chips that do not have an on-chip DRAM. +class NullDramAllocator : public DramAllocator { + public: + NullDramAllocator() = default; + ~NullDramAllocator() override = default; + + // This class is neither copyable nor movable. + NullDramAllocator(const NullDramAllocator&) = delete; + NullDramAllocator& operator=(const NullDramAllocator&) = delete; + + // Always returns an error. + + util::Status Open() override { + return util::OkStatus(); + } + util::Status Close() override { + return util::OkStatus(); + } + + util::StatusOr> AllocateBuffer( + size_t size_bytes) override { + return util::FailedPreconditionError("No on-chip DRAM available."); + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MEMORY_NULL_DRAM_ALLOCATOR_H_ diff --git a/driver/mmio/BUILD b/driver/mmio/BUILD new file mode 100644 index 0000000..8a986b1 --- /dev/null +++ b/driver/mmio/BUILD @@ -0,0 +1,51 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Memory Mapped IO (ex: PCIe.) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "coherent_allocator", + srcs = ["coherent_allocator.cc"], + hdrs = ["coherent_allocator.h"], + deps = [ + "//api:buffer", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +cc_library( + name = "host_queue", + hdrs = ["host_queue.h"], + deps = [ + ":coherent_allocator", + "//api:buffer", + "//driver:device_buffer", + "//driver:hardware_structures", + "//driver/config", + "//driver/memory:address_space", + "//driver/memory:dma_direction", + "//driver/registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) diff --git a/driver/mmio/coherent_allocator.cc b/driver/mmio/coherent_allocator.cc new file mode 100644 index 0000000..cfd60b3 --- /dev/null +++ b/driver/mmio/coherent_allocator.cc @@ -0,0 +1,109 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/mmio/coherent_allocator.h" + +#include "api/buffer.h" +#include "port/aligned_malloc.h" +#include "port/errors.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +constexpr const size_t kDefaultMaxCoherentBytes = 0x10000; +constexpr const size_t kDefaultAlignmentBytes = 8; + +} // namespace + +CoherentAllocator::CoherentAllocator(int alignment_bytes, size_t size_bytes) + : alignment_bytes_(alignment_bytes), total_size_bytes_(size_bytes) { + CHECK_GT(total_size_bytes_, 0); +} + +CoherentAllocator::CoherentAllocator() + : CoherentAllocator(kDefaultAlignmentBytes, kDefaultMaxCoherentBytes) {} + +util::Status CoherentAllocator::Open() { + StdMutexLock lock(&mutex_); + if (coherent_memory_base_ != nullptr) { + return util::FailedPreconditionError("Device already open."); + } + + ASSIGN_OR_RETURN(coherent_memory_base_, DoOpen(total_size_bytes_)); + + return util::Status(); // OK +} + +util::StatusOr CoherentAllocator::DoOpen(size_t size_bytes) { + char *mem_base = + static_cast(aligned_malloc(total_size_bytes_, alignment_bytes_)); + if (mem_base == nullptr) { + return util::FailedPreconditionError( + StringPrintf("Could not malloc %zu bytes.", total_size_bytes_)); + } + memset(mem_base, 0, size_bytes); + return mem_base; // OK +} + +util::StatusOr CoherentAllocator::Allocate(size_t size_bytes) { + StdMutexLock lock(&mutex_); + if (size_bytes == 0) { + return util::FailedPreconditionError("Allocate null size."); + } + + if (coherent_memory_base_ == nullptr) { + return util::FailedPreconditionError("Not Opened."); + } + + if ((allocated_bytes_ + size_bytes) > total_size_bytes_) { + return util::FailedPreconditionError(StringPrintf( + "CoherentAllocator: Allocate size = %zu and no memory (total = %zu).", + size_bytes, total_size_bytes_)); + } + + char *p = coherent_memory_base_ + allocated_bytes_; + + // Power of 2 pointer arithmetic: align the block boundary on chip specific + // byte alignment + size_t mask = alignment_bytes_ - 1; + allocated_bytes_ += (size_bytes + mask) & ~mask; + + return Buffer(p, size_bytes); +} + +util::Status CoherentAllocator::DoClose(char *mem_base, size_t size_bytes) { + if (mem_base != nullptr) { + aligned_free(mem_base); + } + return util::Status(); // OK +} + +util::Status CoherentAllocator::Close() { + StdMutexLock lock(&mutex_); + auto status = DoClose(coherent_memory_base_, total_size_bytes_); + // Resets state. + allocated_bytes_ = 0; + coherent_memory_base_ = nullptr; + + return status; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/mmio/coherent_allocator.h b/driver/mmio/coherent_allocator.h new file mode 100644 index 0000000..7d88d4a --- /dev/null +++ b/driver/mmio/coherent_allocator.h @@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_MMIO_COHERENT_ALLOCATOR_H_ +#define DARWINN_DRIVER_MMIO_COHERENT_ALLOCATOR_H_ + +#include +#include // NOLINT + +#include "api/buffer.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Manage Device Specific DMA-able Coherent memory. +class CoherentAllocator { + public: + CoherentAllocator(); + CoherentAllocator(int alignment_bytes, size_t size_bytes); + virtual ~CoherentAllocator() = default; + + // Opens coherent allocator. + util::Status Open() LOCKS_EXCLUDED(mutex_); + + // Closes coherent allocator. + util::Status Close() LOCKS_EXCLUDED(mutex_); + + // Returns a chunk of coherent memory. + util::StatusOr Allocate(size_t size_bytes) LOCKS_EXCLUDED(mutex_); + + protected: + // Implements Open. + virtual util::StatusOr DoOpen(size_t size_bytes); + + // Implements close. + virtual util::Status DoClose(char *mem_base, size_t size_bytes); + + private: + // Alignment bytes for host memory. + const int alignment_bytes_; + + // User-space virtual address of memory block. + char *coherent_memory_base_{nullptr}; + + // Total size of coherent memory region. + const size_t total_size_bytes_; + + // Coherent Bytes allocated so far. + size_t allocated_bytes_ GUARDED_BY(mutex_){0}; + + // Guards all APIs functions Open/Close/Allocate. + mutable std::mutex mutex_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MMIO_COHERENT_ALLOCATOR_H_ diff --git a/driver/mmio/host_queue.h b/driver/mmio/host_queue.h new file mode 100644 index 0000000..94df964 --- /dev/null +++ b/driver/mmio/host_queue.h @@ -0,0 +1,434 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
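+//
+// A sketch of the intended lifecycle, pieced together from the interface
+// below (the status block type, the CSR-offset getter name, and the queue
+// size are illustrative, not taken from a real call site in this change):
+//
+//   HostQueue<HostQueueDescriptor, HostQueueStatusBlock> queue(
+//       chip_config->GetInstructionQueueCsrOffsets(),  // illustrative getter
+//       chip_config->GetChipStructures(), registers,
+//       gtl::MakeUnique<CoherentAllocator>(), /*size=*/16,
+//       /*single_descriptor_mode=*/false);
+//   CHECK_OK(queue.Open(address_space));
+//   CHECK_OK(queue.Enqueue(descriptor, [](uint32 error) {
+//     LOG_IF(ERROR, error != 0) << "Descriptor failed: " << error;
+//   }));
+//   // The interrupt handler later invokes queue.ProcessStatusBlock(), which
+//   // advances the completed head and runs the callback.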
+ +#ifndef DARWINN_DRIVER_MMIO_HOST_QUEUE_H_ +#define DARWINN_DRIVER_MMIO_HOST_QUEUE_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "api/buffer.h" +#include "driver/config/chip_structures.h" +#include "driver/config/common_csr_helper.h" +#include "driver/config/queue_csr_offsets.h" +#include "driver/device_buffer.h" +#include "driver/hardware_structures.h" +#include "driver/memory/address_space.h" +#include "driver/memory/dma_direction.h" +#include "driver/mmio/coherent_allocator.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/math_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/thread_annotations.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// This class provides high level interface to manage host queue. +template +class HostQueue { + public: + typedef Element queue_element_type; + typedef StatusBlock status_block_type; + + static const uint64 kEnableBit = 1; + static const uint64 kDisableBit = 0; + + HostQueue(const config::QueueCsrOffsets& csr_offsets, + const config::ChipStructures& chip_structures, Registers* registers, + std::unique_ptr coherent_allocator, int size, + bool single_descriptor_mode); + + // This class is neither copyable nor movable. + HostQueue(const HostQueue&) = delete; + HostQueue& operator=(const HostQueue&) = delete; + + virtual ~HostQueue() = default; + + // Open/Close the host queue interface. + virtual util::Status Open(AddressSpace* address_space); + virtual util::Status Close(bool in_error); + util::Status Close() { return Close(/*in_error=*/false); } + + // Enqueue the element into the queue with a callback. Does not block. Returns + // a failure if Enqueue is called when the queue is full. + virtual util::Status Enqueue(const Element& element, + std::function callback); + + // Enable/Disable interrupts. + virtual util::Status EnableInterrupts() { + return RegisterWrite(csr_offsets_.queue_int_control, kEnableBit); + } + virtual util::Status DisableInterrupts() { + return RegisterWrite(csr_offsets_.queue_int_control, kDisableBit); + } + + // Process status block to advance |completed_head_|. For each completed + // element, invoke registered callback. + void ProcessStatusBlock() LOCKS_EXCLUDED(queue_mutex_); + + // Return available space in the queue. + virtual int GetAvailableSpace() const LOCKS_EXCLUDED(queue_mutex_) { + StdMutexLock lock(&queue_mutex_); + return GetAvailableSpaceLocked(); + } + + // Returns the size of the queue. + int size() const { return size_; } + + // Returns true if "address" is within queue address. + bool IsQueueAddress(void* address) const { + return address >= queue_ && address < queue_ + size_; + } + + // Returns true if "address" is within queue and align with a start address of + // a queue entry. + bool IsValidQueueEntry(void *address) const { + if (!IsQueueAddress(address)) return false; + return ( + reinterpret_cast(address) - reinterpret_cast(queue_) + ) % sizeof(Element) == 0; + } + + // Returns true if "address" corresponds to status block address. + bool IsStatusBlockAddress(void* address) const { + return address == status_block_; + } + + private: + // Returns an error if |open_| is not in the specified state. 
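+ // |required| is the expected value of |open_|: operations on a running queue
+ // pass true, while Open() passes false to make sure the queue has not been
+ // opened twice. Callers must hold |open_mutex_|.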
+ util::Status CheckState(bool required) const + SHARED_LOCKS_REQUIRED(open_mutex_) { + if (open_ != required) { + return util::FailedPreconditionError("Invalid state in HostQueue."); + } + return util::Status(); // OK + } + + // Helper method to read register at a given offset. + util::StatusOr<uint64> RegisterRead(uint64 offset) + LOCKS_EXCLUDED(open_mutex_) { + { + StdMutexLock lock(&open_mutex_); + RETURN_IF_ERROR(CheckState(/*required=*/true)); + } + return registers_->Read(offset); + } + + // Helper method to write register at a given offset. + util::Status RegisterWrite(uint64 offset, uint64 value) + LOCKS_EXCLUDED(open_mutex_) { + { + StdMutexLock lock(&open_mutex_); + RETURN_IF_ERROR(CheckState(/*required=*/true)); + } + return registers_->Write(offset, value); + } + + // Helper method to map all the device addresses. + void MapAll() { + DmaDirection dir = DmaDirection::kBidirectional; + Buffer host_queue(queue_, size_ * sizeof(Element)); + device_queue_buffer_ = + address_space_->MapCoherentMemory(host_queue, dir, + MappingTypeHint::kSimple) + .ValueOrDie(); + + VLOG(3) << StringPrintf("Queue base : %p -> 0x%016llx [%lu bytes]", queue_, + device_queue_buffer_.device_address(), + device_queue_buffer_.size_bytes()); + + Buffer host_status_block(status_block_, sizeof(StatusBlock)); + device_status_block_buffer_ = + address_space_ + ->MapCoherentMemory(host_status_block, dir, + MappingTypeHint::kSimple).ValueOrDie(); + + VLOG(3) << StringPrintf("Queue status block : %p -> 0x%016llx [%lu bytes]", + status_block_, + device_status_block_buffer_.device_address(), + device_status_block_buffer_.size_bytes()); + } + + // Helper method to unmap all the device addresses. + util::Status UnmapAll() { + RETURN_IF_ERROR( + address_space_->UnmapCoherentMemory(std::move(device_queue_buffer_))); + RETURN_IF_ERROR( + address_space_->UnmapCoherentMemory( + std::move(device_status_block_buffer_))); + return util::OkStatus(); + } + + // Helper method to return available space in the queue. Because this is a + // circular queue, only (|size_| - 1) elements are available even if nothing + // has been enqueued. + int GetAvailableSpaceLocked() const SHARED_LOCKS_REQUIRED(queue_mutex_) { + if (single_descriptor_mode_) { + return completed_head_ == tail_ ? 1 : 0; + } else { + // Equivalent to: + // (tail_ >= completed_head_) ? (size_ - 1 - (tail_ - completed_head_)) + // : (completed_head_ - 1 - tail_); + return (completed_head_ - tail_ - 1) & (size_ - 1); + } + } + + // Guards open state. + mutable std::mutex open_mutex_; + + // Tracks open state. + bool open_ GUARDED_BY(open_mutex_){false}; + + // If true, only allow one outstanding descriptor at a time. + const bool single_descriptor_mode_{false}; + + // Guards queue state such as |tail_|. + mutable std::mutex queue_mutex_; + + // Guards the state for when callbacks are executing. + mutable std::mutex callback_mutex_; + + // Variables to control queue. + int completed_head_ GUARDED_BY(queue_mutex_){0}; + int tail_ GUARDED_BY(queue_mutex_){0}; + + // Configuration containing all the offsets related to the host queue. + const config::QueueCsrOffsets csr_offsets_; + + // Register interface to perform read/write on. + Registers* const registers_; + + // Coherent allocator interface to get coherent memory DMA-able by the device. + std::unique_ptr<CoherentAllocator> coherent_allocator_; + + // Size of the HostQueue with respect to the number of |Element|. + const int size_; + + // Aligned storage and queue pointer for |Element|.
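+ // Points into the coherent memory handed out by |coherent_allocator_| in
+ // Open(); reset to nullptr in Close().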
+ Element* queue_{nullptr}; + + // Aligned storage and pointer for |StatusBlock|. + StatusBlock* status_block_{nullptr}; + + // Callbacks when the enqueued element is done. Error status from status block + // is passed as an argument. + std::vector> callbacks_ GUARDED_BY(queue_mutex_); + + // Device addresses. + DeviceBuffer device_queue_buffer_; + DeviceBuffer device_status_block_buffer_; + + // Manages device virtual address space. + AddressSpace* address_space_{nullptr}; +}; + +template +HostQueue::HostQueue( + const config::QueueCsrOffsets& csr_offsets, + const config::ChipStructures& chip_structures, Registers* registers, + std::unique_ptr coherent_allocator, int size, + bool single_descriptor_mode) + : single_descriptor_mode_(single_descriptor_mode), + csr_offsets_(csr_offsets), + registers_(registers), + coherent_allocator_(std::move(coherent_allocator)), + size_(size), + callbacks_(size) { + CHECK(registers != nullptr); + // |size_| is power of 2. + CHECK_EQ(size_ & (size_ - 1), 0); + VLOG(3) << "Starting in " + << (single_descriptor_mode ? "single descriptor" : "normal") + << " mode"; +} + +template +util::Status HostQueue::Open( + AddressSpace* address_space) { + StdMutexLock lock(&open_mutex_); + RETURN_IF_ERROR(CheckState(/*required=*/false)); + + if (address_space_ != nullptr) { + return util::InternalError("Address space is already set."); + } + if (address_space == nullptr) { + return util::InvalidArgumentError("Provided address space is null."); + } + address_space_ = address_space; + + // Check for pre-conditions to setup host queue correctly. + ASSIGN_OR_RETURN(auto descriptor_result, + registers_->Read(csr_offsets_.queue_descriptor_size)); + if (descriptor_result != sizeof(Element)) { + return util::InternalError( + "Size of |Element| does not match with the hardware."); + } + + const size_t q_size = + kHostPageSize * + (MathUtil::CeilOfRatio((sizeof(Element) * size_), kHostPageSize)); + const size_t sb_size = + kHostPageSize * + (MathUtil::CeilOfRatio(sizeof(StatusBlock), kHostPageSize)); + + RETURN_IF_ERROR(coherent_allocator_->Open()); + + ASSIGN_OR_RETURN(Buffer queue_mem, coherent_allocator_->Allocate(q_size)); + ASSIGN_OR_RETURN(Buffer status_block_mem, + coherent_allocator_->Allocate(sb_size)); + + queue_ = reinterpret_cast(queue_mem.ptr()); + status_block_ = reinterpret_cast(status_block_mem.ptr()); + + // Allocate device addresses. + MapAll(); + + // Setup queue. + auto status = registers_->Write(csr_offsets_.queue_base, + device_queue_buffer_.device_address()); + status.Update( + registers_->Write(csr_offsets_.queue_status_block_base, + device_status_block_buffer_.device_address())); + status.Update(registers_->Write(csr_offsets_.queue_size, size_)); + if (!status.ok()) { + status.Update(UnmapAll()); + return status; + } + + // Enable the queue, and wait until it's actually enabled. + config::registers::QueueControl control; + control.set_enable(kEnableBit); + control.set_sb_wr_enable(kEnableBit); + RETURN_IF_ERROR(registers_->Write(csr_offsets_.queue_control, control.raw())); + RETURN_IF_ERROR(registers_->Poll(csr_offsets_.queue_status, kEnableBit)); + + open_ = true; + return util::Status(); // OK +} + +template +util::Status HostQueue::Close(bool in_error) { + StdMutexLock lock(&open_mutex_); + StdMutexLock callback_lock(&callback_mutex_); + RETURN_IF_ERROR(CheckState(/*required=*/true)); + + // Disable the queue. 
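+ // Writing kDisableBit to queue_control asks the hardware to stop processing
+ // descriptors; unless we are closing because of an error, the poll below
+ // waits for the disable to take effect before the remaining CSRs are reset.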
+ RETURN_IF_ERROR(registers_->Write(csr_offsets_.queue_control, kDisableBit)); + if (!in_error) { + RETURN_IF_ERROR(registers_->Poll(csr_offsets_.queue_status, 0)); + } + + // Tail is software-write only, and is not reset by the hardware. + auto status = registers_->Write(csr_offsets_.queue_tail, 0); + // Reset device addresses. + status.Update(registers_->Write(csr_offsets_.queue_base, 0)); + status.Update(registers_->Write(csr_offsets_.queue_status_block_base, 0)); + RETURN_IF_ERROR(status); + + // Unmap memory. + RETURN_IF_ERROR(UnmapAll()); + + if (address_space_ == nullptr) { + return util::InternalError("Address space is already null."); + } + + address_space_ = nullptr; + status_block_ = nullptr; + queue_ = nullptr; + completed_head_ = 0; + tail_ = 0; + + // Release coherent memory block. + RETURN_IF_ERROR(coherent_allocator_->Close()); + + open_ = false; + return util::Status(); // OK +} + +template +util::Status HostQueue::Enqueue( + const Element& element, std::function callback) { + TRACE_SCOPE("HostQueue::Enqueue"); + StdMutexLock lock(&queue_mutex_); + if (GetAvailableSpaceLocked() == 0) { + return util::UnavailableError(StringPrintf( + "No space in the queue, completed_head: %d, tail: %d, size: %d", + completed_head_, tail_, size_)); + } + + VLOG(3) << "Adding an element to the host queue."; + + queue_[tail_] = element; + callbacks_[tail_] = std::move(callback); + + ++tail_; + tail_ &= (size_ - 1); + + RETURN_IF_ERROR(RegisterWrite(csr_offsets_.queue_tail, tail_)); + return util::Status(); // OK +} + +template +void HostQueue::ProcessStatusBlock() { + StdMutexLock callback_lock(&callback_mutex_); + int completed = 0; + + StatusBlock status_block = *status_block_; + const int completed_until = status_block.completed_head_pointer; + const uint32 error_status = status_block.fatal_error; + + std::vector> dones; + { + StdMutexLock lock(&queue_mutex_); + while (completed_head_ != completed_until) { + ++completed; + + if (callbacks_[completed_head_]) { + dones.push_back(std::move(callbacks_[completed_head_])); + } + ++completed_head_; + completed_head_ &= (size_ - 1); + } + VLOG(3) << "Completed " << completed << " elements."; + } + + // Clear interrupt pending. + CHECK_OK(RegisterWrite(csr_offsets_.queue_int_status, 0)); + + // Perform callbacks. + for (const auto& done : dones) { + done(error_status); + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MMIO_HOST_QUEUE_H_ diff --git a/driver/mmio_driver.cc b/driver/mmio_driver.cc new file mode 100644 index 0000000..1213538 --- /dev/null +++ b/driver/mmio_driver.cc @@ -0,0 +1,559 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/mmio_driver.h" + +#include +#include +#include +#include +#include +#include + +#include "api/buffer.h" +#include "api/watchdog.h" +#include "driver/config/common_csr_helper.h" +#include "driver/config/register_constants.h" +#include "driver/device_buffer.h" +#include "driver/device_buffer_mapper.h" +#include "driver/dma_info_extractor.h" +#include "driver/hardware_structures.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/interrupt/interrupt_handler.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/memory/address_utilities.h" +#include "driver/memory/mmu_mapper.h" +#include "driver/mmio/host_queue.h" +#include "driver/package_registry.h" +#include "driver/single_tpu_request.h" +#include "driver/time_stamper/driver_time_stamper.h" +#include "driver/top_level_handler.h" +#include "driver/tpu_request.h" +#include "executable/executable_generated.h" +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { + +// Indicates no HIB Fatal Error. +constexpr uint64 kHibErrorStatusNone = 0; + +} // namespace + +MmioDriver::MmioDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::unique_ptr registers, + std::unique_ptr dram_allocator, + std::unique_ptr mmu_mapper, + std::unique_ptr address_space, + std::unique_ptr allocator, + std::unique_ptr> + instruction_queue, + std::unique_ptr interrupt_handler, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr scalar_core_controller, + std::unique_ptr run_controller, + std::unique_ptr top_level_handler, + std::unique_ptr executable_registry, + std::unique_ptr time_stamper) + : Driver( + [](config::ChipConfig* chip_config) { + CHECK(chip_config != nullptr); + return chip_config->GetChip(); + }(chip_config.get()), + std::move(executable_registry), driver_options, + std::move(time_stamper)), + hib_user_csr_offsets_(chip_config->GetHibUserCsrOffsets()), + hib_kernel_csr_offsets_(chip_config->GetHibKernelCsrOffsets()), + chip_structure_(chip_config->GetChipStructures()), + registers_(std::move(registers)), + dram_allocator_(std::move(dram_allocator)), + mmu_mapper_(std::move(mmu_mapper)), + address_space_(std::move(address_space)), + allocator_(std::move(allocator)), + instruction_queue_(std::move(instruction_queue)), + interrupt_handler_(std::move(interrupt_handler)), + top_level_interrupt_manager_(std::move(top_level_interrupt_manager)), + fatal_error_interrupt_controller_( + std::move(fatal_error_interrupt_controller)), + scalar_core_controller_(std::move(scalar_core_controller)), + run_controller_(std::move(run_controller)), + top_level_handler_(std::move(top_level_handler)), + dma_info_extractor_(DmaInfoExtractor::ExtractorType::kInstructionDma), + // TODO : Check reusing driver time_stamper for scheduler. 
+ dma_scheduler_(api::Watchdog::MakeWatchdog( + driver_options.watchdog_timeout_ns(), + [this](int64) { HandleWatchdogTimeout(); }), + gtl::MakeUnique()), + chip_config_(std::move(chip_config)) {} + +MmioDriver::~MmioDriver() { + CHECK_OK(UnregisterAll()); + if (Close(api::Driver::ClosingMode::kGraceful).ok()) { + LOG(WARNING) << "Driver destroyed when open. Forced Close()."; + } +} + +util::Status MmioDriver::ValidateState(State expected_state) const { + if (state_ != expected_state) { + return util::FailedPreconditionError( + StringPrintf("Bad MMIO driver state. expected=%d, actual=%d.", + expected_state, state_)); + } + return util::Status(); // OK +} + +util::Status MmioDriver::SetState(State next_state) { + switch (state_) { + case kOpen: + if (next_state == kClosing) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosing: + if (next_state == kClosed) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosed: + if (next_state == kOpen) { + state_ = next_state; + return util::Status(); // OK + } + break; + } + + // Illegal state transition. + return util::FailedPreconditionError(StringPrintf( + "Invalid state transition. current=%d, next=%d.", state_, next_state)); +} + +util::Status MmioDriver::RegisterAndEnableAllInterrupts() { + // Instruction queue completion. + RETURN_IF_ERROR(interrupt_handler_->Register( + DW_INTERRUPT_INSTR_QUEUE, + std::bind(&HostQueue::ProcessStatusBlock, + instruction_queue_.get()))); + + // Execution completions. + RETURN_IF_ERROR( + interrupt_handler_->Register(DW_INTERRUPT_SC_HOST_0, [this]() { + // We need to clear the interrupts _before_ both: + // - reading interrupt counts, otherwise the device may concurrently + // increment interrupt count without signaling an interrupt. Driver + // can miss the completion event in this case. + // - calling HandleExecutionCompletion() because that may put the + // device in clock gated mode, which causes CSR access to be + // rejected. + CHECK_OK(scalar_core_controller_->ClearInterruptStatus(0)); + + auto count_result = scalar_core_controller_->CheckInterruptCounts(0); + CHECK_OK(count_result.status()); + uint64 count = count_result.ValueOrDie(); + for (int i = 0; i < count; ++i) { + HandleExecutionCompletion(); + } + })); + + // Clear status for other scalar core interrupts. + RETURN_IF_ERROR( + interrupt_handler_->Register(DW_INTERRUPT_SC_HOST_1, [this]() { + CHECK_OK(scalar_core_controller_->ClearInterruptStatus(1)); + })); + RETURN_IF_ERROR( + interrupt_handler_->Register(DW_INTERRUPT_SC_HOST_2, [this]() { + CHECK_OK(scalar_core_controller_->ClearInterruptStatus(2)); + })); + RETURN_IF_ERROR( + interrupt_handler_->Register(DW_INTERRUPT_SC_HOST_3, [this]() { + CHECK_OK(scalar_core_controller_->ClearInterruptStatus(3)); + })); + + // Top level interrupts. + for (int i = 0; i < top_level_interrupt_manager_->NumInterrupts(); ++i) { + RETURN_IF_ERROR(interrupt_handler_->Register( + static_cast(DW_INTERRUPT_TOP_LEVEL_BASE + i), [this, i]() { + LOG(WARNING) << StringPrintf("Top level interrupt: %d", i); + CHECK_OK(top_level_interrupt_manager_->HandleInterrupt(i)); + })); + } + + // HIB Errors. + RETURN_IF_ERROR( + interrupt_handler_->Register(DW_INTERRUPT_FATAL_ERR, [this]() { + // Fatal Error is sticky when raised. Once fatal error is raised, + // disable first and then clear interrupts. 
Note that it is still + // possible for this function to be called multiple times when fatal + // error is raised because of the host side delay involved in disabling + // and clearing the interrupts. This is handled inside CheckFatalError(). + CHECK_OK(fatal_error_interrupt_controller_->DisableInterrupts()); + CHECK_OK(fatal_error_interrupt_controller_->ClearInterruptStatus(0)); + CheckFatalError(CheckHibError()); + })); + + // Enable interrupts, if needed. + RETURN_IF_ERROR(scalar_core_controller_->EnableInterrupts()); + RETURN_IF_ERROR(instruction_queue_->EnableInterrupts()); + RETURN_IF_ERROR(fatal_error_interrupt_controller_->EnableInterrupts()); + + // TODO: refactor for Darwinn 1.0 vs 2.0 driver. + RETURN_IF_ERROR(top_level_interrupt_manager_->EnableInterrupts()); + + return util::Status(); // OK +} + +util::Status MmioDriver::CheckHibError() { + ASSIGN_OR_RETURN(uint64 hib_error_status, + registers_->Read(hib_user_csr_offsets_.hib_error_status)); + if (hib_error_status == kHibErrorStatusNone) { + return util::Status(); // OK + } + + uint64 hib_first_error_status = + registers_->Read(hib_user_csr_offsets_.hib_first_error_status) + .ValueOrDie(); + + auto error_string = StringPrintf( + "HIB Error. hib_error_status = %016llx, hib_first_error_status = %016llx", + static_cast<unsigned long long>(hib_error_status), // NOLINT(runtime/int) + static_cast<unsigned long long>( // NOLINT(runtime/int) + hib_first_error_status)); + LOG(ERROR) << error_string; + return util::InternalError(error_string); +} + +util::Status MmioDriver::DoOpen(bool debug_mode) { + StdMutexLock state_lock(&state_mutex_); + RETURN_IF_ERROR(ValidateState(/*expected_state=*/kClosed)); + + // Register Access. + RETURN_IF_ERROR(registers_->Open()); + auto registers_closer = + MakeCleanup([this] { CHECK_OK(registers_->Close()); }); + + // Reset Handler - Manages power state of the chip. + RETURN_IF_ERROR(top_level_handler_->Open()); + auto top_level_handler_closer = + MakeCleanup([this] { CHECK_OK(top_level_handler_->Close()); }); + + // Disable clock gate and reset GCB for clean state. + RETURN_IF_ERROR(top_level_handler_->DisableSoftwareClockGate()); + RETURN_IF_ERROR(top_level_handler_->DisableHardwareClockGate()); + RETURN_IF_ERROR(top_level_handler_->EnableReset()); + + // Quit from reset mode. + RETURN_IF_ERROR(top_level_handler_->QuitReset()); + RETURN_IF_ERROR(top_level_handler_->EnableHardwareClockGate()); + + // HIB should be good to start with. + RETURN_IF_ERROR(CheckHibError()); + + // Limit AXI DMA burst. + if (hib_user_csr_offsets_.dma_burst_limiter != + kCsrRegisterSpaceInvalidOffset) { + RETURN_IF_ERROR(registers_->Write(hib_user_csr_offsets_.dma_burst_limiter, + chip_structure_.axi_dma_burst_limiter)); + } else { + RETURN_IF_ERROR(registers_->Write(hib_kernel_csr_offsets_.dma_burst_limiter, + chip_structure_.axi_dma_burst_limiter)); + } + + // MMU Access. + const int num_simple_entries = + GetNumSimplePageTableEntries(chip_structure_.num_page_table_entries); + + RETURN_IF_ERROR(mmu_mapper_->Open(num_simple_entries)); + auto mmu_mapper_closer = + MakeCleanup([this] { CHECK_OK(mmu_mapper_->Close()); }); + + // Interrupt Handler. + RETURN_IF_ERROR(interrupt_handler_->Open()); + auto interrupt_handler_closer = + MakeCleanup([this] { CHECK_OK(interrupt_handler_->Close()); }); + + // Instruction Queue Access. + RETURN_IF_ERROR(instruction_queue_->Open(address_space_.get())); + auto instruction_queue_closer = + MakeCleanup([this] { CHECK_OK(instruction_queue_->Close()); }); + + // Scalar core control.
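+ // The controller must be open before RegisterAndEnableAllInterrupts() below,
+ // which clears, reads, and enables its per-core interrupt state.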
+ RETURN_IF_ERROR(scalar_core_controller_->Open()); + auto scalar_core_controller_closer = + MakeCleanup([this] { CHECK_OK(scalar_core_controller_->Close()); }); + + if (!debug_mode) { + // Move all subsystems to Run state. + RETURN_IF_ERROR(run_controller_->DoRunControl(RunControl::kMoveToRun)); + } + + // Disable periodic status block updates. + // TODO: refactor for Darwinn 1.0 vs 2.0 driver. + RETURN_IF_ERROR( + registers_->Write(hib_user_csr_offsets_.status_block_update, 0)); + + // Register and enable all interrupts. + RETURN_IF_ERROR(RegisterAndEnableAllInterrupts()); + + // DMA scheduler. + RETURN_IF_ERROR(dma_scheduler_.Open()); + auto dma_scheduler_closer = MakeCleanup([this] { + CHECK_OK(dma_scheduler_.Close(api::Driver::ClosingMode::kGraceful)); + }); + + // On-Chip DRAM allocator. + RETURN_IF_ERROR(dram_allocator_->Open()); + + // Errata registers. + // TODO: refactor for Darwinn 1.0 vs 2.0 driver. + RETURN_IF_ERROR(FixErrata()); + + // All good. Move state to open. + RETURN_IF_ERROR(SetState(kOpen)); + + // Clock gate until the first request arrives. + RETURN_IF_ERROR(top_level_handler_->EnableSoftwareClockGate()); + + // Release cleanup functions. + dma_scheduler_closer.release(); + scalar_core_controller_closer.release(); + interrupt_handler_closer.release(); + instruction_queue_closer.release(); + mmu_mapper_closer.release(); + top_level_handler_closer.release(); + registers_closer.release(); + + return util::Status(); // OK +} + +util::Status MmioDriver::DoClose(bool in_error, api::Driver::ClosingMode mode) { + StdMutexLock state_lock(&state_mutex_); + RETURN_IF_ERROR(ValidateState(/*expected_state=*/kOpen)); + + // Note our intention to close. + RETURN_IF_ERROR(SetState(kClosing)); + + // Disable Clock Gating so as the closing procedure can access the chip + RETURN_IF_ERROR(top_level_handler_->DisableSoftwareClockGate()); + + // All good. Shut down stuff. This is best effort. So if things starts + // failing, keep going and try cleaning up as much as we can. + util::Status status; + + // Pause all DMAs and wait for that to happen in the hardware otherwise we + // will be at risk of getting into undefined behavior in the following + // steps. + RETURN_IF_ERROR(PauseAllDmas()); + + // Stop all pipelines. + status.Update(run_controller_->DoRunControl(RunControl::kMoveToHalt)); + + // Disable all interrupts. + status.Update(top_level_interrupt_manager_->DisableInterrupts()); + status.Update(fatal_error_interrupt_controller_->DisableInterrupts()); + + status.Update(instruction_queue_->DisableInterrupts()); + status.Update(scalar_core_controller_->DisableInterrupts()); + + // We have to close interrupt handler before host queue especially for ASAP + // closing. Otherwise we may get interrupts that result in an Enqueue in host + // queue while it is closed. + status.Update(interrupt_handler_->Close( + in_error || mode == api::Driver::ClosingMode::kAsap)); + + status.Update(scalar_core_controller_->Close()); + status.Update(instruction_queue_->Close( + in_error || mode == api::Driver::ClosingMode::kAsap)); + + // Begin shutdown. + status.Update(dma_scheduler_.Close(mode)); + status.Update(UnmapAllParameters()); + status.Update(mmu_mapper_->Close()); + status.Update(top_level_handler_->EnableReset()); + status.Update(top_level_handler_->Close()); + status.Update(registers_->Close()); + status.Update(dram_allocator_->Close()); + RETURN_IF_ERROR(status); + + // Finalize. 
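+ // Every subsystem has been shut down at this point, so the driver can be
+ // marked closed.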
+ RETURN_IF_ERROR(SetState(kClosed)); + + return util::Status(); // OK +} + +util::Status MmioDriver::DoCancelAndWaitRequests(bool in_error) { + StdMutexLock state_lock(&state_mutex_); + RETURN_IF_ERROR(dma_scheduler_.CancelPendingRequests()); + if (!in_error) { + RETURN_IF_ERROR(dma_scheduler_.WaitActiveRequests()); + } + return util::Status(); // OK +} + +Buffer MmioDriver::DoMakeBuffer(size_t size_bytes) const { + return allocator_->MakeBuffer(size_bytes); +} + +util::StatusOr<MappedDeviceBuffer> MmioDriver::DoMapBuffer( + const Buffer& buffer, DmaDirection direction) { + if (buffer.IsValid()) { + ASSIGN_OR_RETURN(auto device_buffer, + address_space_->MapMemory(buffer, direction, + MappingTypeHint::kExtended)); + // TODO : this is dangerous: the std::bind captures a raw pointer to + // the address space. This will break if the executable registry outlives + // the address space in the driver. A better way is to at least use + // shared_ptr for address spaces, and here let the std::bind capture a + // weak_ptr. + return MappedDeviceBuffer( + device_buffer, std::bind(&AddressSpace::UnmapMemory, + address_space_.get(), std::placeholders::_1)); + } + return MappedDeviceBuffer(); +} + +util::StatusOr<std::shared_ptr<TpuRequest>> MmioDriver::DoCreateRequest( + const std::shared_ptr<Request> parent_request, + const ExecutableReference* executable, TpuRequest::RequestType type) { + TRACE_SCOPE("MmioDriver::DoCreateRequest"); + StdMutexLock lock(&state_mutex_); + RETURN_IF_ERROR(ValidateState(kOpen)); + return {std::make_shared<SingleTpuRequest>( + next_id_++, parent_request, executable, allocator_.get(), + dram_allocator_.get(), + gtl::MakeUnique<DeviceBufferMapper>(address_space_.get()), + &dma_info_extractor_, chip_structure_.minimum_alignment_bytes, type)}; +} + +util::Status MmioDriver::DoSubmit(std::shared_ptr<TpuRequest> request) { + TRACE_SCOPE("MmioDriver::DoSubmit"); + StdMutexLock state_lock(&state_mutex_); + RETURN_IF_ERROR(ValidateState(kOpen)); + + // Disables clock gating so that the chip is accessible while the request + // is built. + RETURN_IF_ERROR(top_level_handler_->DisableSoftwareClockGate()); + + // Validate and prepare the request. + RETURN_IF_ERROR(request->Validate()); + RETURN_IF_ERROR(request->Prepare()); + + RETURN_IF_ERROR(dma_scheduler_.Submit(std::move(request))); + + TRACE_WITHIN_SCOPE("MmioDriver::DoSubmit::Issue"); + RETURN_IF_ERROR(TryIssueDmas()); + + return util::Status(); // OK +} + +util::Status MmioDriver::TryIssueDmas() { + TRACE_SCOPE("MmioDriver::TryIssueDmas"); + // Both the dma_scheduler and the instruction_queue are thread-safe on their + // own. However, we also want to make sure that DMAs popped from the dma + // scheduler are pushed to the instruction queue in the order they are + // received, so do the following with the dma_issue_mutex held. + StdMutexLock state_lock(&dma_issue_mutex_); + + CHECK_OK(top_level_handler_->DisableSoftwareClockGate()); + + while (instruction_queue_->GetAvailableSpace() > 0) { + ASSIGN_OR_RETURN(auto* dma, dma_scheduler_.GetNextDma()); + if (dma == nullptr) { + break; + } + CHECK(dma->type() == DmaDescriptorType::kInstruction); + + HostQueueDescriptor descriptor{}; + descriptor.address = dma->buffer().device_address(); + descriptor.size_in_bytes = dma->buffer().size_bytes(); + + // Enqueue should always succeed.
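+ // Space was checked in the loop condition while |dma_issue_mutex_| is held,
+ // so a failure here indicates a driver/hardware inconsistency and is treated
+ // as fatal.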
+ CheckFatalError( + instruction_queue_->Enqueue(descriptor, [this, dma](uint32 error_code) { + CHECK_OK(dma_scheduler_.NotifyDmaCompletion(dma)); + HandleHostQueueCompletion(error_code); + })); + + TRACE_WITHIN_SCOPE("MmioDriver::TryIssueDmas::Enqueue"); + } + + return util::OkStatus(); +} + +void MmioDriver::HandleExecutionCompletion() { + TRACE_SCOPE("MmioDriver::HandleExecutionCompletion"); + CHECK_OK(dma_scheduler_.NotifyRequestCompletion()); + HandleTpuRequestCompletion(); + if (dma_scheduler_.IsEmpty()) { + CHECK_OK(top_level_handler_->EnableSoftwareClockGate()); + } +} + +void MmioDriver::HandleHostQueueCompletion(uint32 error_code) { + TRACE_SCOPE("MmioDriver::HostQueueCompletion"); + if (error_code != 0) { + // TODO: Parse the error code and attach a human readable string. + CheckFatalError( + util::InternalError(StringPrintf("Host Queue error %d.", error_code))); + return; + } + CHECK_OK(TryIssueDmas()); +} + +void MmioDriver::CheckFatalError(const util::Status& status) { + if (status.ok()) { + return; + } + NotifyFatalError(status); +} + +util::Status MmioDriver::DoSetRealtimeMode(bool on) { + dma_scheduler_.SetRealtimeMode(on); + return util::OkStatus(); +} + +util::Status MmioDriver::PauseAllDmas() { + constexpr uint64 kPauseDmas = 1; + RETURN_IF_ERROR( + registers_->Write(hib_user_csr_offsets_.dma_pause, kPauseDmas)); + constexpr uint64 kAllDmasPaused = 1; + return registers_->Poll(hib_user_csr_offsets_.dma_paused, kAllDmasPaused); +} + +util::Status MmioDriver::FixErrata() { + return util::OkStatus(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/mmio_driver.h b/driver/mmio_driver.h new file mode 100644 index 0000000..f80501e --- /dev/null +++ b/driver/mmio_driver.h @@ -0,0 +1,266 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_DRIVER_MMIO_DRIVER_H_ +#define DARWINN_DRIVER_MMIO_DRIVER_H_ + +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "api/allocated_buffer.h" +#include "api/buffer.h" +#include "driver/allocator.h" +#include "driver/config/chip_config.h" +#include "driver/config/chip_structures.h" +#include "driver/config/hib_kernel_csr_offsets.h" +#include "driver/config/hib_user_csr_offsets.h" +#include "driver/device_buffer.h" +#include "driver/dma_info_extractor.h" +#include "driver/driver.h" +#include "driver/hardware_structures.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/interrupt/interrupt_handler.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/memory/address_space.h" +#include "driver/memory/dma_direction.h" +#include "driver/memory/dram_allocator.h" +#include "driver/memory/mmu_mapper.h" +#include "driver/mmio/host_queue.h" +#include "driver/package_registry.h" +#include "driver/real_time_dma_scheduler.h" +#include "driver/registers/registers.h" +#include "driver/run_controller.h" +#include "driver/scalar_core_controller.h" +#include "driver/time_stamper/time_stamper.h" +#include "driver/top_level_handler.h" +#include "driver/tpu_request.h" +#include "executable/executable_generated.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// DarwiNN driver implementation that talks to the device through memory-mapped +// IO setup with a kernel device driver. Thread safe. +class MmioDriver : public Driver { + public: + MmioDriver( + const api::DriverOptions& options, + std::unique_ptr chip_config, + std::unique_ptr registers, + std::unique_ptr dram_allocator, + std::unique_ptr mmu_mapper, + std::unique_ptr address_space, + std::unique_ptr allocator, + std::unique_ptr> + instruction_queue, + std::unique_ptr interrupt_handler, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr scalar_core_controller, + std::unique_ptr run_controller, + std::unique_ptr top_level_handler, + std::unique_ptr executable_registry, + std::unique_ptr time_stamper); + + // This class is neither copyable nor movable. + MmioDriver(const MmioDriver&) = delete; + MmioDriver& operator=(const MmioDriver&) = delete; + + ~MmioDriver() override; + + uint64_t allocation_alignment_bytes() const override { + return chip_structure_.allocation_alignment_bytes; + } + + protected: + util::Status DoOpen(bool debug_mode) LOCKS_EXCLUDED(state_mutex_) override; + util::Status DoClose(bool in_error, api::Driver::ClosingMode mode) + LOCKS_EXCLUDED(state_mutex_) override; + util::Status DoCancelAndWaitRequests(bool in_error) + LOCKS_EXCLUDED(state_mutex_) override; + + Buffer DoMakeBuffer(size_t size_bytes) const override; + + util::StatusOr DoMapBuffer( + const Buffer& buffer, DmaDirection direction) override; + + util::StatusOr> DoCreateRequest( + const std::shared_ptr parent_request, + const ExecutableReference* executable, TpuRequest::RequestType type) + LOCKS_EXCLUDED(state_mutex_) override; + + // We do support real-time mode in this driver. 
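+ // The actual scheduling policy lives in |dma_scheduler_|; see
+ // DoSetRealtimeMode() and DoSetExecutableTiming() below.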
+ bool HasImplementedRealtimeMode() const final { return true; } + + util::Status DoSetExecutableTiming(const ExecutableReference* executable, + const api::Timing& timing) final { + return dma_scheduler_.SetExecutableTiming(executable, timing); + } + + util::Status DoRemoveExecutableTiming(const ExecutableReference* executable) { + return dma_scheduler_.RemoveExecutableTiming(executable); + } + + util::Status DoSetRealtimeMode(bool on) final; + + util::Status DoSubmit(std::shared_ptr request) + LOCKS_EXCLUDED(state_mutex_) override; + + int64 MaxRemainingCycles() const override { + return dma_scheduler_.MaxRemainingCycles(); + } + + // Returns a pointer to the registers in this driver. The pointer is valid as + // long as the driver instance is. + Registers* registers() { return registers_.get(); } + + // Returns a reference to the chip config. It is valid as long as the + // MmioDriver instance is. + const config::ChipConfig& chip_config() const { + return *chip_config_; + } + + util::StatusOr> GetOldestActiveRequest() + const override { + return dma_scheduler_.GetOldestActiveRequest(); + } + + private: + // TODO: Eliminate state management here. Since this is now done + // in the base class. + // Driver state. Transitions : + // kClosed -> kOpen -> kClosing -> kClosed. + enum State { + kOpen, // Driver is Open. + kClosing, // Driver is Closing. + kClosed, // Driver is Closed. (Initial state.) + }; + + // Attempts a state transition to the given state. + util::Status SetState(State next_state) + EXCLUSIVE_LOCKS_REQUIRED(state_mutex_); + + // Validates that we are in the expected state. + util::Status ValidateState(State expected_state) const + SHARED_LOCKS_REQUIRED(state_mutex_); + + // Attempts to issue as many DMAs as possible. + util::Status TryIssueDmas() LOCKS_EXCLUDED(dma_issue_mutex_); + + // Handles request execution completions. + void HandleExecutionCompletion(); + + // Handles instruction queue pop notifications. + void HandleHostQueueCompletion(uint32 error_code); + + // Checks for HIB Errors. + util::Status CheckHibError(); + + // Catch all fatal error handling during runtime. + void CheckFatalError(const util::Status& status); + + // Registers and enables all interrupts. + util::Status RegisterAndEnableAllInterrupts(); + + // Pauses all the DMAs and returns once that is verified. + util::Status PauseAllDmas() EXCLUSIVE_LOCKS_REQUIRED(state_mutex_); + + // Programs errata CSRs to disable hardware features with known issues. + util::Status FixErrata(); + + // CSR offsets. + const config::HibUserCsrOffsets& hib_user_csr_offsets_; + const config::HibKernelCsrOffsets& hib_kernel_csr_offsets_; + + // Chip structure. + const config::ChipStructures& chip_structure_; + + // Register interface. + std::unique_ptr registers_; + + // The object responsible for allocating on-chip DRAM buffers (if supported). + std::unique_ptr dram_allocator_; + + // MMU Mapper. + std::unique_ptr mmu_mapper_; + + // Address space management. + std::unique_ptr address_space_; + + // Host buffer allocator. + std::unique_ptr allocator_; + + // Instruction queue. + std::unique_ptr> + instruction_queue_; + + // Interrupt handler. + std::unique_ptr interrupt_handler_; + + // Top level interrupt manager. + std::unique_ptr top_level_interrupt_manager_; + + // Fatal error interrupt controller. + std::unique_ptr + fatal_error_interrupt_controller_; + + // Scalar core controller. + std::unique_ptr scalar_core_controller_; + + // Run controller. + std::unique_ptr run_controller_; + + // Reset handler. 
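+ // Controls chip reset and software/hardware clock gating; see DoOpen() and
+ // DoClose().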
+ std::unique_ptr top_level_handler_; + + // Maintains integrity of the driver state. + std::mutex state_mutex_; + + // Ensures that DMAs produced by the dma scheduler is submitted + // in order to the instruction queue. + std::mutex dma_issue_mutex_; + + // Driver state. + State state_ GUARDED_BY(state_mutex_){kClosed}; + + // When in state |kClosing|, a notification to wait for all active + // requests to complete. + std::condition_variable wait_active_requests_complete_; + + // ID for tracking requests. + std::atomic next_id_{0}; + + // DMA info extractor. + DmaInfoExtractor dma_info_extractor_; + + // DMA scheduler. + RealTimeDmaScheduler dma_scheduler_; + + // Chip configuration. + std::unique_ptr chip_config_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_MMIO_DRIVER_H_ diff --git a/driver/package_registry.cc b/driver/package_registry.cc new file mode 100644 index 0000000..1dcffa8 --- /dev/null +++ b/driver/package_registry.cc @@ -0,0 +1,803 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/package_registry.h" + +#include +#include +#include +#include +#include + +#include "api/package_reference.h" +#include "api/runtime_version.h" +#include "driver/aligned_allocator.h" +#include "driver/package_verifier.h" +#include "executable/executable_generated.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// Alignment for buffers allocated by the registry. +constexpr uint64 kAlignment = 4096; + +} // namespace + +PackageRegistry::PackageRegistry() : PackageRegistry(api::Chip::kUnknown) {} + +PackageRegistry::PackageRegistry(api::Chip chip) + : PackageRegistry(chip, gtl::MakeUnique(), nullptr) {} + +PackageRegistry::PackageRegistry( + api::Chip chip, std::unique_ptr executable_verifier, + DramAllocator* dram_allocator) + : allocator_(kAlignment), + dram_allocator_(dram_allocator), + chip_(chip), + verifier_(std::move(executable_verifier)) {} + +util::StatusOr> +PackageRegistry::GetExecutablesFromBinary(const char* executable_content, + size_t length) { + // Check the file identifier of the package. + std::string package_identifier( + flatbuffers::GetBufferIdentifier(executable_content), + flatbuffers::FlatBufferBuilder::kFileIdentifierLength); + if (package_identifier != api::kHeadPackageIdentifier) { + LOG(WARNING) << StringPrintf("Package file identifier %s not supported.", + package_identifier.c_str()); + } + + // Verify and get the package from the memory mapped buffer. 
+ flatbuffers::Verifier package_verifier( + reinterpret_cast(executable_content), length); + if (!package_verifier.VerifyBuffer()) { + return util::InternalError("Package verification failed."); + } + auto* package = flatbuffers::GetRoot(executable_content); + + // The runtime version check shall always be the first after parsing, so it's + // possible to introduce non-backward-compatible changes. + const auto min_runtime_version = package->min_runtime_version(); + if (min_runtime_version < api::RuntimeVersion::kMinValidRuntimeVersion) { + LOG(WARNING) << StringPrintf( + "Minimum runtime version required by package (%d) is lower than " + "expected (%d).", + min_runtime_version, api::RuntimeVersion::kMinValidRuntimeVersion); + } else if (min_runtime_version > api::RuntimeVersion::kCurrent) { + return util::FailedPreconditionError(StringPrintf( + "Package requires runtime version (%d), which is newer " + "than this runtime version (%d).", + package->min_runtime_version(), api::RuntimeVersion::kCurrent)); + } + + constexpr int kVirtualChipIdForMultiChipPackage = -1; + if (package->virtual_chip_id() == kVirtualChipIdForMultiChipPackage) { + return util::FailedPreconditionError("This is a multi-chip package."); + } + + if (flatbuffers::VectorLength(package->serialized_multi_executable()) == 0) { + return util::FailedPreconditionError("No executables to register."); + } + + // Verify and get the MultiExecutable table from the package. + flatbuffers::Verifier multi_executable_verifier( + package->serialized_multi_executable()->data(), + flatbuffers::VectorLength(package->serialized_multi_executable())); + if (!multi_executable_verifier.VerifyBuffer()) { + return util::InternalError("MultiExecutable verification failed."); + } + auto* multi_executable = flatbuffers::GetRoot( + package->serialized_multi_executable()->data()); + + // Extract the buffer pointer for the serialized executable from the + // MultiExecutable. + + if (flatbuffers::VectorLength(multi_executable->serialized_executables()) == + 0) { + return util::NotFoundError("No executables provided."); + } + + return ExtractExecutables(*multi_executable); +} + +util::StatusOr +PackageRegistry::GetMainExecutableFromExecutableMap( + std::unordered_map executables) { + switch (executables.size()) { + case 1: + // TODO Here we are considering the sole executable in a + // package as stand-alone no matter what the type specifies. This is for + // being backward-compatible with the old-style parameter-caching. Once + // that is deprecated, here we should look for the STAND_ALONE type. + return executables.begin()->second; + + case 2: + return executables[ExecutableType_EXECUTION_ONLY]; + + // TODO Once this feature is implemented, we need to update the + // constructor used here. Right now we still allow 3 executables in a + // package to avoid future backward-incompatibility. The current behavior is + // to always use the stand-alone one. 
+ case 3: + return executables[ExecutableType_STAND_ALONE]; + + default: + return util::InternalError("Unexpected combination of executables."); + } +} + +util::StatusOr +PackageRegistry::GetPCExecutableFromExecutableMap( + std::unordered_map executables) { + switch (executables.size()) { + case 1: + return nullptr; + case 2: + return executables[ExecutableType_PARAMETER_CACHING]; + case 3: + return nullptr; + default: + return util::InternalError("Unexpected combination of executables."); + } +} + +util::StatusOr PackageRegistry::RegisterPackage( + const Buffer& package_buffer) { + ASSIGN_OR_RETURN(auto executables, + GetExecutablesFromBinary( + reinterpret_cast(package_buffer.ptr()), + package_buffer.size_bytes())); + + for (const auto& it : executables) { + RETURN_IF_ERROR(VerifyExecutableMatchesChip(it.second)); + } + + ASSIGN_OR_RETURN(const Executable* main_executable, + GetMainExecutableFromExecutableMap(executables)); + ASSIGN_OR_RETURN(const Executable* parameter_caching_executable, + GetPCExecutableFromExecutableMap(executables)); + + PackageReference* package_reference; + + if (parameter_caching_executable != nullptr) { + package_reference = new PackageReference( + package_buffer, parameter_caching_executable, main_executable, + &allocator_, dram_allocator_, verifier_.get()); + } else { + package_reference = + new PackageReference(package_buffer, main_executable, &allocator_, + dram_allocator_, verifier_.get()); + } + + return SetRegistrations( + std::unique_ptr(package_reference)); +} + +util::StatusOr> +PackageRegistry::GetMainExecutableLayersInfoFromBinary( + const char* executable_content, size_t length) { + ASSIGN_OR_RETURN(auto executables, + GetExecutablesFromBinary(executable_content, length)); + + ASSIGN_OR_RETURN(const Executable* main_executable, + GetMainExecutableFromExecutableMap(executables)); + + return gtl::MakeUnique(main_executable); +} + +util::StatusOr> +PackageRegistry::ExtractExecutables(const MultiExecutable& multi_executable) { + std::unordered_map executables; + + // Fetch executables to a map of type -> executable. + for (const auto* executable_serialized : + *multi_executable.serialized_executables()) { + ASSIGN_OR_RETURN(auto executable, + FetchAndVerifyExecutable(executable_serialized->c_str(), + executable_serialized->size())); + + if (executables.find(executable->type()) != executables.end()) { + return util::InvalidArgumentError( + "Multiple executables of the same type were found in the package."); + } + executables[executable->type()] = executable; + } + + // Sanity check for legal combinations. 
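+  // Legal combinations are: a single executable of any type, a
+  // parameter-caching + execution-only pair, or a parameter-caching +
+  // execution-only + stand-alone triple.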
+ switch (executables.size()) { + case 0: + return util::InternalError("No executables provided."); + + case 1: + break; + + case 2: + if (executables.find(ExecutableType_PARAMETER_CACHING) == + executables.end() || + executables.find(ExecutableType_EXECUTION_ONLY) == + executables.end()) { + return util::InvalidArgumentError( + "Invalid combination of executables in the package."); + } + break; + + case 3: + if (executables.find(ExecutableType_PARAMETER_CACHING) == + executables.end() || + executables.find(ExecutableType_EXECUTION_ONLY) == + executables.end() || + executables.find(ExecutableType_STAND_ALONE) == executables.end()) { + return util::InvalidArgumentError( + "Invalid combination of executables in the package."); + } + break; + + default: + return util::InvalidArgumentError( + "Found executable types that are not yet supported."); + } + + return executables; +} + +util::StatusOr PackageRegistry::FetchAndVerifyExecutable( + const char* executable_serialized, size_t length) { + flatbuffers::Verifier verifier( + reinterpret_cast(executable_serialized), length); + if (!verifier.VerifyBuffer()) { + return util::InvalidArgumentError("Executable verification failed."); + } + + const auto* executable = flatbuffers::GetRoot( + reinterpret_cast(executable_serialized)); + + // All executables must have a batch size of at least one. + if (executable->batch_size() < 1) { + return util::InvalidArgumentError("Executable has invalid batch size."); + } + + return executable; +} + +util::Status PackageRegistry::VerifyExecutableMatchesChip( + const Executable* executable) const { + return util::OkStatus(); +} + +util::StatusOr +PackageRegistry::RegisterSerialized(const std::string& executable_content) { + return RegisterSerialized(executable_content.data(), + executable_content.size()); +} + +util::StatusOr +PackageRegistry::RegisterSerialized(const char* executable_content, + size_t length) { + Buffer package_buffer = allocator_.MakeBuffer(length); + CHECK(package_buffer.ptr() != nullptr); + memcpy(package_buffer.ptr(), executable_content, length); + return RegisterPackage(package_buffer); +} + +util::StatusOr PackageRegistry::RegisterFile( + const std::string& executable_filename) { + std::ifstream ifs; + ifs.open(executable_filename, std::ifstream::in); + if (!ifs.is_open()) { + return util::InvalidArgumentError( + StringPrintf("Cannot open %s.", executable_filename.c_str())); + } + + ifs.seekg(0, std::ios_base::end); + size_t file_size(ifs.tellg()); + ifs.seekg(std::ios_base::beg); + + Buffer package_buffer = allocator_.MakeBuffer(file_size); + CHECK(package_buffer.ptr() != nullptr); + ifs.read(reinterpret_cast(package_buffer.ptr()), file_size); + ifs.close(); + + return RegisterPackage(package_buffer); +} + +util::Status PackageRegistry::Unregister( + const api::PackageReference* package_reference) { + StdMutexLock registrations_lock(®istrations_mutex_); + + // Bail out early if package_reference isn't valid. 
+ if (package_reference == nullptr) { + return util::InvalidArgumentError("Provided package reference in null."); + } + if (registrations_.count(package_reference) == 0) { + return util::NotFoundError( + "Attempting to unregister a nonexistent executable reference."); + } + + PackageReference* driver_package_ref = const_cast( + static_cast(package_reference)); + + ASSIGN_OR_RETURN(auto parameters_mapped, + driver_package_ref->ParametersMapped()); + if (parameters_mapped) { + RETURN_IF_ERROR(driver_package_ref->UnmapParameters()); + } + + // TODO : Need to track outstanding requests and error when + // there are pending/in-flight requests at un-registration time. + if (registrations_.erase(driver_package_ref) == 0) { + return util::NotFoundError( + "Attempting to unregister a nonexistent executable reference."); + } + + return util::Status(); // OK. +} + +util::Status PackageRegistry::UnregisterAll() { + RETURN_IF_ERROR(UnmapAllParameters()); + + StdMutexLock registrations_lock(®istrations_mutex_); + // TODO : Need to track outstanding requests and error when + // there are pending/in-flight requests at un-registration time. + registrations_.clear(); + + return util::OkStatus(); +} + +util::Status PackageRegistry::UnmapAllParameters() { + StdMutexLock registrations_lock(®istrations_mutex_); + util::Status status; + + for (auto& it : registrations_) { + if (it.first == nullptr) { + return util::InternalError( + "Encountered nullptr key to package reference."); + } + PackageReference* package = const_cast( + static_cast(it.first)); + + const auto parameters_mapped = package->ParametersMapped(); + if (!parameters_mapped.ok()) { + status.Update(parameters_mapped.status()); + continue; + } + + if (parameters_mapped.ValueOrDie()) { + status.Update(package->UnmapParameters()); + } + } + + return status; +} + +std::vector PackageRegistry::GetAllRegistrations() + const { + StdMutexLock registrations_lock(®istrations_mutex_); + + std::vector package_refs; + package_refs.reserve(registrations_.size()); + for (auto& registration : registrations_) { + package_refs.push_back(registration.second.get()); + } + return package_refs; +} + +const api::PackageReference* PackageRegistry::SetRegistrations( + std::unique_ptr api_package_ref) { + StdMutexLock registrations_lock(®istrations_mutex_); + + auto api_reference = + registrations_.emplace(api_package_ref.get(), std::move(api_package_ref)) + .first->first; + + return api_reference; +} + +void PackageRegistry::ResetParametersLoaded() { + StdMutexLock registrations_lock(®istrations_mutex_); + for (auto& registration : registrations_) { + auto package_ref = + static_cast(registration.second.get()); + for (auto exec_ref : package_ref->AllExecutableReferences()) { + exec_ref->ResetParametersLoaded(); + } + } +} + +ExecutableLayersInfo::ExecutableLayersInfo(const Executable* executable) { + // Set layer information. 
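+  // Walk the input and output layers of the executable, recording layer
+  // names, name-to-index maps, and whether any layer must be cached in
+  // on-chip DRAM.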
+ const int input_layer_count = + flatbuffers::VectorLength(executable->input_layers()); + inputs_.reserve(input_layer_count); + input_layer_names_.reserve(input_layer_count); + + for (int i = 0; i < input_layer_count; ++i) { + const auto& layer_name = executable->input_layers()->Get(i)->name()->str(); + api::InputLayerInformation layer(executable->input_layers()->Get(i)); + if (layer.CacheOnDram()) { + needs_dram_in_layers_ = true; + } + inputs_.emplace_back(layer); + input_layer_names_.emplace_back(layer_name); + input_map_[layer_name] = i; + } + + const int output_layer_count = + flatbuffers::VectorLength(executable->output_layers()); + outputs_.reserve(output_layer_count); + output_layer_names_.reserve(output_layer_count); + + for (int i = 0; i < output_layer_count; ++i) { + const auto& layer_name = executable->output_layers()->Get(i)->name()->str(); + api::OutputLayerInformation layer(executable->output_layers()->Get(i)); + if (layer.CacheOnDram()) { + needs_dram_in_layers_ = true; + } + outputs_.emplace_back(layer); + output_layer_names_.emplace_back(layer_name); + output_map_[layer_name] = i; + } +} + +util::StatusOr ExecutableLayersInfo::InputIndex( + const std::string& name) const { + auto iter = input_map_.find(name); + if (iter != input_map_.end()) { + return iter->second; + } + return util::NotFoundError( + StringPrintf("Input layer '%s' not found.", name.c_str())); +} + +util::StatusOr ExecutableLayersInfo::OutputIndex( + const std::string& name) const { + auto iter = output_map_.find(name); + if (iter != output_map_.end()) { + return iter->second; + } + return util::NotFoundError( + StringPrintf("Output layer '%s' not found.", name.c_str())); +} + +const api::InputLayerInformation* ExecutableLayersInfo::InputLayer( + int index) const { + if (index < inputs_.size()) { + return &inputs_[index]; + } + return nullptr; +} + +const api::OutputLayerInformation* ExecutableLayersInfo::OutputLayer( + int index) const { + if (index < outputs_.size()) { + return &outputs_[index]; + } + return nullptr; +} + +// TODO If possible, refactor this to only have a name map and we +// will not need to do 2 lookups. 
+util::StatusOr +ExecutableLayersInfo::InputLayer(const std::string& layer_name) const { + ASSIGN_OR_RETURN(auto index, InputIndex(layer_name)); + const auto* input_info = InputLayer(index); + if (input_info == nullptr) { + return util::InternalError( + StringPrintf("Input layer %s was not found in executable reference.", + layer_name.c_str())); + } + return input_info; +} + +util::StatusOr +ExecutableLayersInfo::OutputLayer(const std::string& layer_name) const { + ASSIGN_OR_RETURN(auto index, OutputIndex(layer_name)); + const auto* output_info = OutputLayer(index); + if (output_info == nullptr) { + return util::InternalError( + StringPrintf("Output layer %s was not found in executable reference.", + layer_name.c_str())); + } + return output_info; +} + +util::StatusOr ExecutableLayersInfo::InputLayerSizeBytes( + const std::string& name) const { + ASSIGN_OR_RETURN(int index, InputIndex(name)); + return inputs_[index].ActualSizeBytes(); +} + +util::StatusOr ExecutableLayersInfo::InputLayerPaddedSizeBytes( + const std::string& name) const { + ASSIGN_OR_RETURN(int index, InputIndex(name)); + return inputs_[index].PaddedSizeBytes(); +} + +util::StatusOr ExecutableLayersInfo::OutputLayerSizeBytes( + const std::string& name) const { + ASSIGN_OR_RETURN(int index, OutputIndex(name)); + return outputs_[index].ActualSizeBytes(); +} + +ExecutableReference::ExecutableReference(const Executable* executable, + Allocator* allocator, + DramAllocator* dram_allocator, + PackageReference* pkg_ref) + : executable_(executable), + package_reference_(pkg_ref) { + // Create a buffer for parameters. This buffer is either in host or in the + // on-chip DRAM. If on host, we already have a copy of the data in the package + // flatbuffer. If on chip, we will copy the data the first time the buffer is + // mapped. + auto parameter_size_bytes = + flatbuffers::VectorLength(executable->parameters()); + if (parameter_size_bytes > 0) { + // TODO Remove the check on nullptr. + if (executable->use_tpu_dram_for_parameters() && + dram_allocator != nullptr) { + auto buffer_or_error = + dram_allocator->AllocateBuffer(parameter_size_bytes); + if (buffer_or_error.ok()) { + parameters_ = Buffer(std::move(buffer_or_error.ValueOrDie())); + needs_dram_ = true; + } else { + LOG(WARNING) << StringPrintf( + "Failed to allocate TPU DRAM buffer of size %zu " + "for parameters: ", + parameter_size_bytes) + << buffer_or_error.status().message(); + parameters_ = Buffer( + reinterpret_cast(executable->parameters()->data()), + parameter_size_bytes); + } + } else { + parameters_ = Buffer( + reinterpret_cast(executable->parameters()->data()), + parameter_size_bytes); + } + } + + // Allocate scratch if necessary. It is preferred to have scratch in the + // on-chip DRAM. + // + // TODO Check if the chip does have a DRAM. + if (executable->scratch_size_bytes() > 0) { + if (dram_allocator != nullptr) { + auto buffer_or_error = + dram_allocator->AllocateBuffer(executable->scratch_size_bytes()); + if (buffer_or_error.ok()) { + scratch_ = Buffer(std::move(buffer_or_error.ValueOrDie())); + needs_dram_ = true; + } else { + scratch_ = allocator->MakeBuffer(executable->scratch_size_bytes()); + } + } else { + scratch_ = allocator->MakeBuffer(executable->scratch_size_bytes()); + } + } + + // Extracts the input and output layers info from the executable binary. + executable_layers_info_ = gtl::MakeUnique(executable); + + // The DRAM will be needed if any of the component needs to access it. 
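+  // (Parameters, scratch memory, or any input/output layer may require it.)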
+ if (executable_layers_info_->NeedsDramInLayers()) { + needs_dram_ = true; + } +} + +util::Status ExecutableReference::ValidateInput(const std::string& input_name, + const Buffer& input) const { + ASSIGN_OR_RETURN(const auto* layer, InputLayer(input_name)); + + // We can only accept buffers that are the same size as the input layer tensor + // with or without padding. + if (input.size_bytes() != layer->ActualSizeBytes() && + input.size_bytes() != layer->PaddedSizeBytes()) { + return util::InvalidArgumentError(StringPrintf( + "Unexpected input size for \"%s\". Expected %d or %d, got %zu", + input_name.c_str(), layer->ActualSizeBytes(), layer->PaddedSizeBytes(), + input.size_bytes())); + } + + return util::OkStatus(); +} + +util::Status ExecutableReference::ValidateOutput(const std::string& output_name, + const Buffer& output) const { + ASSIGN_OR_RETURN(const int expected_size_bytes, + OutputLayerSizeBytes(output_name)); + if (output.size_bytes() != expected_size_bytes) { + return util::InvalidArgumentError(StringPrintf( + "Unexpected output size for \"%s\". expected=%d, actual=%zu.", + output_name.c_str(), expected_size_bytes, output.size_bytes())); + } + return util::Status(); // OK +} + +// Reuses the instruction buffers if available. Creates a new one if not. +std::unique_ptr ExecutableReference::GetInstructionBuffers( + Allocator* const allocator) { + TRACE_SCOPE("ExecutableReference::GetInstructionBuffers"); + StdMutexLock lock(&instruction_buffers_vector_mutex_); + + if (!instruction_buffers_vector_.empty()) { + std::unique_ptr old_instruction_buffers = + std::move(instruction_buffers_vector_.back()); + instruction_buffers_vector_.pop_back(); + VLOG(10) << "Reusing old instruction buffers."; + + return old_instruction_buffers; + } + + auto instruction_buffers = gtl::MakeUnique( + allocator, *executable().instruction_bitstreams()); + + VLOG(10) << "Created new instruction buffers."; + return instruction_buffers; +} + +// Returns instruction buffers back to the executable references so that the +// next request could reuse it. +void ExecutableReference::ReturnInstructionBuffers( + std::unique_ptr instruction_buffers) { + StdMutexLock lock(&instruction_buffers_vector_mutex_); + + instruction_buffers_vector_.push_back(std::move(instruction_buffers)); + VLOG(10) << "Returned instruction buffers back to executable reference"; +} + +util::Status ExecutableReference::PrepareParameters() { + // If parameters are not in on-chip DRAM or they have already been loaded + // there, nothing else to do here. + if (!parameters_.IsDramType() || parameters_loaded_) { + return util::OkStatus(); + } + + ASSIGN_OR_RETURN(auto dram_buffer, parameters_.GetDramBuffer()); + // TODO Get rid of this const_cast. 
+ RETURN_IF_ERROR(dram_buffer->ReadFrom(const_cast( + reinterpret_cast(executable_->parameters()->data())))); + + parameters_loaded_ = true; + VLOG(2) << "Parameters were loaded on DRAM."; + + return util::OkStatus(); +} + +void ExecutableReference::ResetParametersLoaded() { + if (parameters_.IsDramType()) { + parameters_loaded_ = false; + } +} + +util::Status ExecutableReference::SetMappedParameters( + MappedDeviceBuffer&& mapped_parameters) { + if (parameters_mapped_) { + RETURN_IF_ERROR(mapped_parameters.Unmap()); + return util::FailedPreconditionError("Parameters are already mapped."); + } + + mapped_parameters_ = std::move(mapped_parameters); + parameters_mapped_ = true; + + return util::OkStatus(); +} + +util::Status ExecutableReference::UnmapParameters() { + if (!parameters_mapped_) { + return util::FailedPreconditionError( + "Parameters are not currently mapped."); + } + + RETURN_IF_ERROR(mapped_parameters_.Unmap()); + parameters_mapped_ = false; + + return util::OkStatus(); +} + +PackageReference::PackageReference(const Buffer& package_buffer, + const Executable* standalone_executable, + Allocator* allocator, + DramAllocator* dram_allocator, + PackageVerifier* verifier) + : package_buffer_(package_buffer), + package_(flatbuffers::GetRoot(package_buffer.ptr())), + verifier_(verifier), + standalone_reference_(new driver::ExecutableReference( + standalone_executable, allocator, dram_allocator, this)) {} + +PackageReference::PackageReference( + const Buffer& package_buffer, + const Executable* parameter_caching_executable, + const Executable* inference_executable, Allocator* allocator, + DramAllocator* dram_allocator, PackageVerifier* verifier) + : package_buffer_(package_buffer), + package_(flatbuffers::GetRoot(package_buffer.ptr())), + verifier_(verifier), + parameter_caching_reference_(new driver::ExecutableReference( + parameter_caching_executable, allocator, dram_allocator, this)), + inference_reference_(new driver::ExecutableReference( + inference_executable, allocator, dram_allocator, this)) {} + +std::vector +PackageReference::AllExecutableReferences() const { + std::vector all_references; + if (standalone_reference_ != nullptr) { + all_references.push_back(standalone_reference_.get()); + } + if (parameter_caching_reference_ != nullptr) { + all_references.push_back(parameter_caching_reference_.get()); + } + if (inference_reference_ != nullptr) { + all_references.push_back(inference_reference_.get()); + } + return all_references; +} + +util::Status PackageReference::UnmapParameters() { + util::Status status; + + for (ExecutableReference* executable_ref : AllExecutableReferences()) { + status.Update(executable_ref->UnmapParameters()); + } + + return status; +} + +util::StatusOr PackageReference::ParametersMapped() const { + auto all_executable_refs = AllExecutableReferences(); + if (all_executable_refs.empty()) { + return util::FailedPreconditionError( + "No executable references were found in the package reference."); + } + bool parameters_mapped = all_executable_refs.front()->ParametersMapped(); + + for (auto* executable_ref : all_executable_refs) { + if (executable_ref->ParametersMapped() != parameters_mapped) { + return util::InternalError( + "Inconsistent parameter mapping status across executables in the " + "same package."); + } + } + + return parameters_mapped; +} + +bool PackageReference::NeedsDram() const { + auto all_executable_refs = AllExecutableReferences(); + + for (auto* executable_ref : all_executable_refs) { + if (executable_ref->NeedsDram()) { + return true; + 
} + } + + return false; +} + +util::Status PackageReference::SetLatencyTolerance(int64 latency_tolerance_ms) { + latency_tolerance_ms_ = latency_tolerance_ms; + return util::OkStatus(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/package_registry.h b/driver/package_registry.h new file mode 100644 index 0000000..05941cb --- /dev/null +++ b/driver/package_registry.h @@ -0,0 +1,807 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_PACKAGE_REGISTRY_H_ +#define DARWINN_DRIVER_PACKAGE_REGISTRY_H_ + +#include +#include // NOLINT +#include +#include +#include + +#include "api/buffer.h" +#include "api/chip.h" +#include "api/driver_options_generated.h" +#include "api/execution_context_interface.h" +#include "api/layer_information.h" +#include "api/package_reference.h" +#include "driver/aligned_allocator.h" +#include "driver/device_buffer_mapper.h" +#include "driver/instruction_buffers.h" +#include "driver/memory/dram_allocator.h" +#include "driver/package_verifier.h" +#include "executable/executable_generated.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +class PackageReference; + +// Holds the input and output layer info from an executable. +class ExecutableLayersInfo { + public: + ExecutableLayersInfo(const Executable* executable); + ~ExecutableLayersInfo() = default; + + // This class is neither copyable nor movable. + ExecutableLayersInfo(const ExecutableLayersInfo&) = delete; + ExecutableLayersInfo& operator=(const ExecutableLayersInfo&) = delete; + + // Returns the index of input layer with given name. + util::StatusOr InputIndex(const std::string& name) const; + + // Returns the index of output layer with given name. + util::StatusOr OutputIndex(const std::string& name) const; + + // Returns number of input layers. + int NumInputLayers() const { return inputs_.size(); } + + // Returns number of output layers. + int NumOutputLayers() const { return outputs_.size(); } + + // Returns list of input layer names. + const std::vector& InputLayerNames() const { + return input_layer_names_; + } + + // Returns list of output layer names. + const std::vector& OutputLayerNames() const { + return output_layer_names_; + } + + // Returns information on given input layer. Returns nullptr if index is out + // of bounds. + const api::InputLayerInformation* InputLayer(int index) const; + + // Returns information on given output layer. Returns nullptr if index is out + // of bounds. + const api::OutputLayerInformation* OutputLayer(int index) const; + + // Returns information on given input layer. + util::StatusOr InputLayer( + const std::string& layer_name) const; + + // Returns information on given output layer. 
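+  // Returns an error if no output layer with the given name exists.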
+ util::StatusOr OutputLayer( + const std::string& layer_name) const; + + // Returns the expected byte size of activations for given input layer index. + int InputLayerSizeBytes(int index) const { + CHECK(InputLayer(index) != nullptr); + return InputLayer(index)->ActualSizeBytes(); + } + + // Returns the expected byte size of activations for given input layer index. + // This is post-padding, if any. + // TODO Remove this method. + int InputLayerPaddedSizeBytes(int index) const { + CHECK(InputLayer(index) != nullptr); + return InputLayer(index)->PaddedSizeBytes(); + } + + // Returns the expected byte size of activations for given output layer index. + int OutputLayerSizeBytes(int index) const { + CHECK(OutputLayer(index) != nullptr); + return OutputLayer(index)->ActualSizeBytes(); + } + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + int InputLayerSize(int index) const { + auto layer = InputLayer(index); + CHECK(layer != nullptr); + return layer->y_dim() * layer->x_dim() * layer->z_dim() * + layer->execution_count_per_inference(); + } + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + int OutputLayerSize(int index) const { + auto layer = OutputLayer(index); + CHECK(layer != nullptr); + return layer->y_dim() * layer->x_dim() * layer->z_dim() * + layer->execution_count_per_inference(); + } + + // Returns the expected size of activations for given input layer. + // Prefer index based APIs for performance. + util::StatusOr InputLayerSizeBytes(const std::string& name) const; + + // Returns the expected size of activations for given input layer including + // padding bytes (if any). + // Prefer index based APIs for performance. + // TODO Remove this method. + util::StatusOr InputLayerPaddedSizeBytes(const std::string& name) const; + + // Returns the expected size of activations for given output layer. + // Prefer index based APIs for performance. + util::StatusOr OutputLayerSizeBytes(const std::string& name) const; + + // Returns name for given input layer index. + std::string InputLayerName(int index) const { + CHECK(InputLayer(index) != nullptr); + return InputLayer(index)->name(); + } + + // Returns name for given output layer index. + std::string OutputLayerName(int index) const { + CHECK(OutputLayer(index) != nullptr); + return OutputLayer(index)->name(); + } + + // Returns if on-chip DRAM is needed in either input or output layers. + bool NeedsDramInLayers() const { return needs_dram_in_layers_; } + + private: + // Vector with list of input layer names. + std::vector input_layer_names_; + + // Vector with list of output layer names. + std::vector output_layer_names_; + + // Vector with detailed input layer information. + std::vector inputs_; + + // Vector with detailed outpu layer information. + std::vector outputs_; + + // Maps input layer names to indices. + std::unordered_map input_map_; + + // Maps output layer names to indices. + std::unordered_map output_map_; + + // Specifies if this executable needs on-chip DRAM for input or output layers. + bool needs_dram_in_layers_ = false; +}; + +// Reference to a single executable. +class ExecutableReference { + public: + // This class is neither copyable nor movable. + ExecutableReference(const ExecutableReference&) = delete; + ExecutableReference& operator=(const ExecutableReference&) = delete; + + // Returns the index of input layer with given name. 
+ util::StatusOr InputIndex(const std::string& name) const { + return executable_layers_info_->InputIndex(name); + } + + // Returns the index of output layer with given name. + util::StatusOr OutputIndex(const std::string& name) const { + return executable_layers_info_->OutputIndex(name); + } + + // Returns number of input layers. + int NumInputLayers() const { + return executable_layers_info_->NumInputLayers(); + } + + // Returns number of output layers. + int NumOutputLayers() const { + return executable_layers_info_->NumOutputLayers(); + } + + // Returns list of input layer names. + const std::vector& InputLayerNames() const { + return executable_layers_info_->InputLayerNames(); + } + + // Returns list of output layer names. + const std::vector& OutputLayerNames() const { + return executable_layers_info_->OutputLayerNames(); + } + + // Returns information on given input layer. Returns nullptr if index is out + // of bounds. + const api::InputLayerInformation* InputLayer(int index) const { + return executable_layers_info_->InputLayer(index); + } + + // Returns information on given output layer. Returns nullptr if index is out + // of bounds. + const api::OutputLayerInformation* OutputLayer(int index) const { + return executable_layers_info_->OutputLayer(index); + } + + // Returns information on given input layer. + util::StatusOr InputLayer( + const std::string& layer_name) const { + return executable_layers_info_->InputLayer(layer_name); + } + + // Returns information on given output layer. + util::StatusOr OutputLayer( + const std::string& layer_name) const { + return executable_layers_info_->OutputLayer(layer_name); + } + + // Returns the expected byte size of activations for given input layer index. + int InputLayerSizeBytes(int index) const { + return executable_layers_info_->InputLayerSizeBytes(index); + } + + // Returns the expected byte size of activations for given input layer index. + // This is post-padding, if any. + // TODO Remove this method. + int InputLayerPaddedSizeBytes(int index) const { + return executable_layers_info_->InputLayerPaddedSizeBytes(index); + } + + // Returns the expected byte size of activations for given output layer index. + int OutputLayerSizeBytes(int index) const { + return executable_layers_info_->OutputLayerSizeBytes(index); + } + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + int InputLayerSize(int index) const { + return executable_layers_info_->InputLayerSize(index); + } + + // Returns the expected size (in value count) of activations for given input + // layer index. This is pre-padding, if any. + int OutputLayerSize(int index) const { + return executable_layers_info_->OutputLayerSize(index); + } + + // Returns the expected size of activations for given input layer. + // Prefer index based APIs for performance. + util::StatusOr InputLayerSizeBytes(const std::string& name) const { + return executable_layers_info_->InputLayerSizeBytes(name); + } + + // Returns the expected size of activations for given input layer including + // padding bytes (if any). + // Prefer index based APIs for performance. + // TODO Remove this method. + util::StatusOr InputLayerPaddedSizeBytes(const std::string& name) const { + return executable_layers_info_->InputLayerPaddedSizeBytes(name); + } + + // Returns the expected size of activations for given output layer. + // Prefer index based APIs for performance. 
+ util::StatusOr OutputLayerSizeBytes(const std::string& name) const { + return executable_layers_info_->OutputLayerSizeBytes(name); + } + + // Returns name for given input layer index. + std::string InputLayerName(int index) const { + return executable_layers_info_->InputLayerName(index); + } + + // Returns name for given output layer index. + std::string OutputLayerName(int index) const { + return executable_layers_info_->OutputLayerName(index); + } + + // Returns batch size. + int BatchSize() const { return executable_->batch_size(); } + + // Executable + const darwinn::Executable& executable() const { return *executable_; } + + // Memory aligned copy of the parameters. + const Buffer& parameters() const { return parameters_; } + + // Sets mapped parameters. + // Move-only. The given mapped_parameter will be unmapped during destruction + // time, so we cannot allow copy-construction, to avoid doubly unmapping a + // Device Buffer. + util::Status SetMappedParameters(MappedDeviceBuffer&& mapped_parameters); + + // Unmaps the parameters buffer from the device. + util::Status UnmapParameters(); + + // Returns true if the parameters buffer is already mapped to the device. + bool ParametersMapped() const { return parameters_mapped_; } + + // Returns the device mapped buffer for the parameters in this executable. + const DeviceBuffer& GetParameterDeviceBuffer() const { + return mapped_parameters_.device_buffer(); + } + + // Scratch buffer, if the executable requires scratch space. If not, then the + // buffer will be invalid. + Buffer scratch() const { return scratch_; } + + // Validates that the given input buffer is compatible with the executable. + util::Status ValidateInput(const std::string& input_name, + const Buffer& input) const; + + // Validates that the given output buffer is compatible with the executable. + util::Status ValidateOutput(const std::string& output_name, + const Buffer& output) const; + + // Returns the parameter-caching token which is unique across models that are + // compiled together and can cache their parameters on TPU SRAM at the same + // time. If 0, it means this executable's parameters cannot safely co-exist + // with those of others. Please note that tokens are not limited to parameter + // cached models. We could have a stand-alone compiled model that still has + // a token to ensure us it will not overwite cached parameters of other + // models. + uint64 ParameterCachingToken() const { + return executable().parameter_caching_token(); + } + + // Returns the estimated runtime in cycles for this executable. + int64 EstimatedCycles() const { + return executable().estimated_cycles_64bit(); + } + + // Reuses or creates instruction buffers. + std::unique_ptr GetInstructionBuffers( + Allocator* allocator); + + // Returns instruction buffers back to executable reference. + // TODO: Add pool size limit. This currently doesn't have size + // limit, and if there are many requests happened at the same time, we might + // increase the total memory footprint. Notice that this won't increase + // the memory peak size but will hold them longer. If this becomes an issue + // we should investigate what's the correct size limit. + void ReturnInstructionBuffers( + std::unique_ptr instruction_buffers); + + // Makes sure parameters are present in host or TPU DRAM to be used in an + // inference. This method is not thread-safe. + util::Status PrepareParameters(); + + // Resets any assumptions about parameters being loaded on TPU DRAM. This + // method is not thread-safe. 
+ void ResetParametersLoaded(); + + // Specifies if this executable needs on-chip DRAM to execute. + bool NeedsDram() const { return needs_dram_; } + + // Returns the amount of narrow memory (in bytes) used by each tile in this + // executable. + int64 UsedNarrowMemoryBytesPerTile() const { + return executable_->used_narrow_memory_bytes_per_tile(); + } + + const PackageReference& GetPackageReference() const { + return *package_reference_; + } + + private: + friend class PackageReference; + + // Allow constructions through the friend classes only. + ExecutableReference(const Executable* executable, Allocator* allocator, + DramAllocator* dram_allocator, PackageReference* pkg_ref); + + // Memory aligned copy of parameters. + Buffer parameters_; + + // Mapped parameters. + MappedDeviceBuffer mapped_parameters_; + + // Scratch buffer, if the executable requires scratch space. If not, then the + // buffer will be invalid. + Buffer scratch_; + + // Holds the backing executable. + const Executable* executable_; + + // Holds the information on input and output layers. + std::unique_ptr executable_layers_info_; + + mutable std::mutex instruction_buffers_vector_mutex_; + std::vector> instruction_buffers_vector_ + GUARDED_BY(instruction_buffers_vector_mutex_); + + // Specifies if parameters of this executable are mapped to the device. + bool parameters_mapped_ = false; + + // Specifies if parameters are already loaded to on-chip DRAM. + bool parameters_loaded_ = false; + + // Specifies if this executable needs on-chip DRAM to execute. + // The DRAM might be needed in input and output layers, parameters, or scratch + // memory. + bool needs_dram_ = false; + + // Pointer to package reference that contains this object. This object does + // not own package_reference_. + PackageReference *package_reference_; +}; + +// Contains an executable package. +class PackageReference : public api::PackageReference { + public: + // This class is neither copyable nor movable. + PackageReference(const PackageReference&) = delete; + PackageReference& operator=(const PackageReference&) = delete; + + // Verifies the digital signature in the package. + util::Status VerifySignature() const override { + return verifier_->VerifySignature(package_buffer_.ptr()); + } + + // The following methods just call their counterpart in + // MainExecutableReference(). 
+ util::StatusOr InputIndex(const std::string& name) const override { + return MainExecutableReference()->InputIndex(name); + } + + util::StatusOr OutputIndex(const std::string& name) const override { + return MainExecutableReference()->OutputIndex(name); + } + + int NumInputLayers() const override { + return MainExecutableReference()->NumInputLayers(); + } + + int NumOutputLayers() const override { + return MainExecutableReference()->NumOutputLayers(); + } + + const std::vector& InputLayerNames() const override { + return MainExecutableReference()->InputLayerNames(); + } + + const std::vector& OutputLayerNames() const override { + return MainExecutableReference()->OutputLayerNames(); + } + + const api::InputLayerInformation* InputLayer(int index) const override { + return MainExecutableReference()->InputLayer(index); + } + + const api::OutputLayerInformation* OutputLayer(int index) const override { + return MainExecutableReference()->OutputLayer(index); + } + + util::StatusOr InputLayer( + const std::string& layer_name) const override { + return MainExecutableReference()->InputLayer(layer_name); + } + + util::StatusOr OutputLayer( + const std::string& layer_name) const override { + return MainExecutableReference()->OutputLayer(layer_name); + } + + int InputLayerSizeBytes(int index) const override { + return MainExecutableReference()->InputLayerSizeBytes(index); + } + + // TODO Remove this method. + int InputLayerPaddedSizeBytes(int index) const override { + return MainExecutableReference()->InputLayerPaddedSizeBytes(index); + } + + int OutputLayerSizeBytes(int index) const override { + return MainExecutableReference()->OutputLayerSizeBytes(index); + } + + int InputLayerSize(int index) const override { + return MainExecutableReference()->InputLayerSize(index); + } + + int OutputLayerSize(int index) const override { + return MainExecutableReference()->OutputLayerSize(index); + } + + util::StatusOr InputLayerSizeBytes( + const std::string& name) const override { + return MainExecutableReference()->InputLayerSizeBytes(name); + } + + // TODO Remove this method. + util::StatusOr InputLayerPaddedSizeBytes( + const std::string& name) const override { + return MainExecutableReference()->InputLayerPaddedSizeBytes(name); + } + + util::StatusOr OutputLayerSizeBytes( + const std::string& name) const override { + return MainExecutableReference()->OutputLayerSizeBytes(name); + } + + std::string InputLayerName(int index) const override { + return MainExecutableReference()->InputLayerName(index); + } + + std::string OutputLayerName(int index) const override { + return MainExecutableReference()->OutputLayerName(index); + } + + int BatchSize() const override { + return MainExecutableReference()->BatchSize(); + } + + util::Status SetLatencyTolerance(int64 latency_tolerance_ms) override; + + // Returns a vector of all executable references in this package. + std::vector AllExecutableReferences() const; + + // Returns the main executable reference to refer to for read-only information + // (e.g. number of layers). + const driver::ExecutableReference* MainExecutableReference() const { + return standalone_reference_ ? standalone_reference_.get() + : inference_reference_.get(); + } + + // Returns true if this package is parameter-caching-enabled. + bool ParameterCachingEnabled() const { + return parameter_caching_reference_ != nullptr; + } + + // Returns the inference executable reference in this package. You can make + // sure such reference exists by calling ParameterCachingEnabled method. 
+ const driver::ExecutableReference* InferenceExecutableReference() const { + return inference_reference_.get(); + } + + // Returns the parameter-caching executable reference in this package. You can + // make sure such reference exists by calling ParameterCachingEnabled method. + const driver::ExecutableReference* ParameterCachingExecutableReference() + const { + return parameter_caching_reference_.get(); + } + + // Returns true if parameters of this package are mapped to the device. + util::StatusOr ParametersMapped() const; + + // Specifies if this package needs on-chip DRAM to execute. + bool NeedsDram() const; + + // Sets the execution context interface. This class owns the execution + // context. + void SetExecutionContextInterface( + std::unique_ptr + execution_context_interface) override { + execution_context_interface_ = std::move(execution_context_interface); + } + + std::string ModelIdentifier() const override { + return flatbuffers::GetString(package_->model_identifier()); + } + + // Returns the stored execution context interface. This class still owns the + // object. + api::ExecutionContextInterface* GetExecutionContextInterface() const { + return execution_context_interface_.get(); + } + + // Returns the amount of time in milliseconds that this package can tolerate + // for an inference to run (including parameter-caching). If batched, then for + // all batch elements to complete. + int64 LatencyToleranceMs() const { return latency_tolerance_ms_; } + + private: + friend class PackageRegistry; + + // Allow constructions through the ExecutableRegistry class only. + // + // The current implementation allows either a stand-alone executable or + // parameter-caching + inference. + + // Constructor for stand-alone executable. + PackageReference(const Buffer& package_buffer, + const Executable* standalone_executable, + Allocator* allocator, DramAllocator* dram_allocator, + PackageVerifier* verifier); + + // Constructor for a parameter cached package. + PackageReference(const Buffer& package_buffer, + const Executable* parameter_caching_executable, + const Executable* inference_executable, Allocator* allocator, + DramAllocator* dram_allocator, PackageVerifier* verifier); + + // Unmaps parameters of all executables in this package. + util::Status UnmapParameters(); + + // Buffer backing the package buffer. + Buffer package_buffer_; + + // The flatbuffer representation of the package we are wrapping. + const Package* package_; + + // A ExecutableVerifier for checking digital signatures on executable + // packages. + const PackageVerifier* const verifier_; + + // The ExecutableReference for the parameter-caching executable. + std::unique_ptr parameter_caching_reference_; + + // The Executable reference for the inference executable. + std::unique_ptr inference_reference_; + + // The Executable reference for the stand-alone executable. + std::unique_ptr standalone_reference_; + + // The execution context for the package reference. + std::unique_ptr execution_context_interface_; + + // Maximum number of milliseconds this package can tolerate for an inference + // request to run. + int64 latency_tolerance_ms_ = -1; +}; + +// A registry for executable files. Most methods do not have built-in thread- +// safety and rely on base driver class to ensure that. Some methods require +// thread-safety with respect to their call-sites in base driver and that is +// implemented in this class. 
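+// Typical flow: RegisterSerialized()/RegisterFile() copy and verify a
+// package flatbuffer and return an api::PackageReference; Unregister() or
+// UnregisterAll() unmap any mapped parameters and drop the registration.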
+class PackageRegistry { + public: + PackageRegistry(api::Chip chip, + std::unique_ptr executable_verifier, + DramAllocator* dram_allocator); + + // Constructs an ExecutableRegistry that will check to make sure all + // registered executables are for the correct chip. However, it does not + // support DRAM nor signature verification. Please prefer the first + // constructor. + explicit PackageRegistry(api::Chip chip); + + // Constructs an ExecutableRegistry that does not check if the executables + // being registered are for the correct device. Please prefer the first + // constructor. + PackageRegistry(); + + ~PackageRegistry() = default; + + // This class is neither copyable nor movable. + PackageRegistry(const PackageRegistry&) = delete; + PackageRegistry& operator=(const PackageRegistry&) = delete; + + // Returns the main executable layer info from the given executable binary. + static util::StatusOr> + GetMainExecutableLayersInfoFromBinary(const char* executable_content, + size_t length); + + // Registers a serialized executable binary. Once the executable is + // registered, driver has its own copy of it so there would be no need to keep + // the executable_content in memory. + util::StatusOr RegisterSerialized( + const std::string& executable_content); + util::StatusOr RegisterSerialized( + const char* executable_content, size_t length); + + // Same as above, but the executable is read from the given file. + util::StatusOr RegisterFile( + const std::string& executable_filename); + + // Unregisters an executable. Invokes the callback to unmap the parameter. + util::Status Unregister(const api::PackageReference* package_reference); + + // Unregisteres all registered executables. + util::Status UnregisterAll() LOCKS_EXCLUDED(registrations_mutex_); + + // Unmaps all parameters in all registered packages. + util::Status UnmapAllParameters() LOCKS_EXCLUDED(registrations_mutex_); + + // Returns the number of registered executables. + int GetRegistrySize() const LOCKS_EXCLUDED(registrations_mutex_) { + StdMutexLock registration_lock(®istrations_mutex_); + return registrations_.size(); + } + + // Returns all the package references registered here. + std::vector GetAllRegistrations() const + LOCKS_EXCLUDED(registrations_mutex_); + + // Resets any assumptions about parameters of any current registrations being + // loaded on TPU DRAM. + void ResetParametersLoaded() LOCKS_EXCLUDED(registrations_mutex_); + + private: + // Returns the main executable from the executable map. + // Returns error if failed to find main executable or had unexpected + // executable combinations. + static util::StatusOr GetMainExecutableFromExecutableMap( + std::unordered_map); + + // Returns the parameter caching executable from the executable map. + // Returns nullptr if no parameter caching executable could be found in the + // map. + // Returns error if had unexpected executable combinations. + static util::StatusOr GetPCExecutableFromExecutableMap( + std::unordered_map); + + // Registers an executable package binary. + util::StatusOr RegisterPackage( + const Buffer& package_buffer); + + // Takes in a multi-executable and returns a map of each executable type to + // its executable pointer. It will return an error in any illegal combination. + // Legitimate combinations are: + // + // 1. A single executable (no matter what type). + // 2. 1 parameter-caching and 1 inference executable. + // 3. 1 parameter-caching, 1 inference and 1 stand-alone executable. 
+ static util::StatusOr> + ExtractExecutables(const MultiExecutable& multi_executable); + + // Takes in executable binary content and returns a map of each executable + // type to its executable pointer. + // The inputs are pointers to the executable binary, the length of the binary, + // and the targeted chip to run this executables on. + static util::StatusOr> + GetExecutablesFromBinary(const char* executable_content, size_t length); + + // Fetches an Executable from its serialized version and performs some + // verification checks (does not include signature verification). + static util::StatusOr FetchAndVerifyExecutable( + const char* executable_serialized, size_t length); + + // Checks if the chip config specified in the executable binary matches the + // one registered to this package registry. + util::Status VerifyExecutableMatchesChip(const Executable*) const; + + const api::PackageReference* SetRegistrations( + std::unique_ptr api_package_ref) + LOCKS_EXCLUDED(registrations_mutex_); + + // Allocator. + AlignedAllocator allocator_; + + // A pointer to the entity responsible for allocating on-chip DRAM buffers. + DramAllocator* dram_allocator_; + + // A mutex to synchronize access to registrations_. + mutable std::mutex registrations_mutex_; + + // Tracks registrations. + // Uses a map instead of a set, since looking up an std::set of unique_ptr is + // tricky. + std::unordered_map> + registrations_ GUARDED_BY(registrations_mutex_); + + // Executables will be checked to ensure they were compiled for this chip. + // Can be kUnknown to disable checking. + const api::Chip chip_; + + // A verifier for checking digital signatures on executable packages. + std::unique_ptr verifier_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + + +namespace std { + template<> + struct hash<::platforms::darwinn::ExecutableType> { + typedef ::platforms::darwinn::ExecutableType argument_type; + typedef std::underlying_type::type underlying_type; + typedef std::hash::result_type result_type; + result_type operator()(const argument_type& arg) const { + std::hash hasher; + return hasher(static_cast(arg)); + } + }; +} + +#endif // DARWINN_DRIVER_PACKAGE_REGISTRY_H_ diff --git a/driver/package_verifier.cc b/driver/package_verifier.cc new file mode 100644 index 0000000..247acb7 --- /dev/null +++ b/driver/package_verifier.cc @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "driver/package_verifier.h" + +#include + +#include "executable/executable_generated.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status NoopPackageVerifier::VerifySignature(const void*) const { + return util::FailedPreconditionError( + "No verifier was created yet verification was requested."); +} + +util::StatusOr> MakeExecutableVerifier( + const std::string& public_key_path) { + + return {gtl::MakeUnique()}; +} + +util::StatusOr> MakeExecutableVerifierFromFile( + const std::string& public_key_path) { + + return {gtl::MakeUnique()}; +} +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/package_verifier.h b/driver/package_verifier.h new file mode 100644 index 0000000..a78b67b --- /dev/null +++ b/driver/package_verifier.h @@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_PACKAGE_VERIFIER_H_ +#define DARWINN_DRIVER_PACKAGE_VERIFIER_H_ + +#include +#include + +#include "port/defs.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/openssl.h" +#include "port/statusor.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// ExecutableVerifier is a class to verify executable packages using digital +// signatures. +class PackageVerifier { + public: + virtual ~PackageVerifier() = default; + + // Verifies the executable package provided its buffer. + virtual util::Status VerifySignature(const void* package_buffer) const = 0; +}; + +// A noop implementation of ExecutableVerifier that errors out on all calls. +class NoopPackageVerifier : public PackageVerifier { + public: + NoopPackageVerifier() = default; + ~NoopPackageVerifier() override = default; + util::Status VerifySignature(const void* package_buffer) const override; +}; + +// Makes an ExecutableVerifier provided a public key. If the key is empty a noop +// verifier will be returned that errors on Verify. +util::StatusOr> MakeExecutableVerifier( + const std::string& public_key); + +// Makes an ExecutableVerifier provided a file path to the public key. If the +// path is empty a noop verifier will be returned that errors on Verify. +util::StatusOr> MakeExecutableVerifierFromFile( + const std::string& public_key_path); + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_PACKAGE_VERIFIER_H_ diff --git a/driver/real_time_dma_scheduler.cc b/driver/real_time_dma_scheduler.cc new file mode 100644 index 0000000..d0cca23 --- /dev/null +++ b/driver/real_time_dma_scheduler.cc @@ -0,0 +1,231 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/real_time_dma_scheduler.h" + +#include "absl/strings/str_format.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status RealTimeDmaScheduler::Open() { return backing_scheduler_->Open(); } + +util::Status RealTimeDmaScheduler::Close(api::Driver::ClosingMode mode) { + ResetTiming(); + return backing_scheduler_->Close(mode); +} + +util::Status RealTimeDmaScheduler::Submit(std::shared_ptr request) { + StdMutexLock lock(&mutex_); + if (!real_time_mode_) { + return backing_scheduler_->Submit(request); + } + + const auto* executable_ref = &request->executable_reference(); + const int64 time_now_us = time_stamper_->GetTimeMicroSeconds(); + + // TODO: We allocate Timing information every time a new + // executable requesting inference, but we need to decide a good way and when + // to cleanup the timing information. + auto& cur_timing = inference_timings_[executable_ref]; + + // Updating the arrival time regardless whether we can schedule it or not. + cur_timing.last_arrival_time_us = time_now_us; + + // A normal process w/o max execution time cannot be scheduled at this time. + // TODO: move such processes into a queue that would be submitted + // when leaving real-time mode. + if (cur_timing.max_execution_time_ms == 0) { + if (cur_timing.fps != 0) { + // FPS > 0, and 0 MET: ill-formed Timing information. + return util::InvalidArgumentError( + "Unable to submit under real-time mode. " + "Ill-formed timing information: FPS > 0 but MET == 0."); + } + return util::DeadlineExceededError( + "Normal process without MET cannot be scheduled in real-time mode."); + } + + // Can we submit? Computing T_deadline using the basic algorithm. + int64 deadline_us = INT64_MAX; + time_booked_us_ = std::max(time_now_us, time_booked_us_); + + for (const auto& inference_timing : inference_timings_) { + if (inference_timing.first == executable_ref) { + continue; // No need to count self. + } + const TimingInternal& timing = inference_timing.second; + if (!timing.HasRealTimeRequirements()) { + // For normal models, assuming no ETA of next request. + continue; + } + if (timing.last_arrival_time_us == 0) { + continue; // The model has no past inferences. Skip. + } + + ASSIGN_OR_RETURN(const int64 frame_time_us, timing.frame_time_us()); + const int64 time_next_us = + timing.last_arrival_time_us + frame_time_us + + std::min(timing.tolerance_us(), + frame_time_us - timing.max_execution_time_us()); + // If we missed two frames already, assume it's no longer arriving. + // TODO: verify if that assumption holds. + if (time_next_us + 2 * frame_time_us < time_now_us) { + continue; + } + + deadline_us = std::min(deadline_us, time_next_us); + } + + if (deadline_us > time_booked_us_ + cur_timing.max_execution_time_us()) { + // Ok to schedule; add this to booked time. 
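To make the admission test above concrete, the following standalone sketch traces the same arithmetic with made-up numbers: one registered 30 FPS model with a 10 ms MET and 5 ms tolerance whose last request arrived at t = 100,000 us, and a second request with a 15 ms MET arriving at t = 120,000 us.

```cpp
// Standalone arithmetic sketch of the admission test; all values are
// illustrative and not tied to any real model.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Registered real-time model: 30 FPS, MET 10 ms, tolerance 5 ms,
  // last request arrived at t = 100,000 us.
  const int64_t frame_us = static_cast<int64_t>(1e6 / 30);  // 33,333
  const int64_t met_us = 10000, tol_us = 5000;
  const int64_t last_arrival_us = 100000;
  const int64_t time_next_us =
      last_arrival_us + frame_us + std::min(tol_us, frame_us - met_us);
  const int64_t deadline_us = time_next_us;                 // 138,333

  // Incoming request: arrives now (t = 120,000 us) and needs MET 15 ms.
  const int64_t now_us = 120000;
  int64_t booked_us = std::max(now_us, int64_t{0});         // nothing booked yet
  const int64_t new_met_us = 15000;
  if (deadline_us > booked_us + new_met_us) {               // 138,333 > 135,000
    booked_us += new_met_us;
    std::cout << "admitted; booked until t=" << booked_us << " us\n";
  } else {
    std::cout << "rejected: would push past the other model's deadline\n";
  }
  return 0;
}
```

Here deadline_us works out to 138,333 us, which still leaves room for the extra 15 ms, so the request is admitted and the booked time advances to 135,000 us.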
+ time_booked_us_ += cur_timing.max_execution_time_us(); + return backing_scheduler_->Submit(request); + } else { + return util::DeadlineExceededError( + "The request cannot be scheduled within given time budget."); + } +} + +void RealTimeDmaScheduler::SetRealtimeMode(bool on) { + StdMutexLock lock(&mutex_); + real_time_mode_ = on; +} + +int64 RealTimeDmaScheduler::GetLastArrivalTime( + const ExecutableReference* executable) const { + StdMutexLock lock(&mutex_); + + auto found = inference_timings_.find(executable); + if (found == inference_timings_.end()) { + return 0; + } else { + return found->second.last_arrival_time_us; + } +} + +util::Status RealTimeDmaScheduler::RemoveExecutableTiming( + const ExecutableReference* executable) { + if (executable == nullptr) { + return util::InvalidArgumentError("Null executable refernce."); + } + + StdMutexLock lock(&mutex_); + // It is not an error if we don't have timing for an executable: for older + // binaries they might not have the estimated clock cycle information and thus + // may or may not have been registered here. + inference_timings_.erase(executable); + return util::OkStatus(); +} + +util::Status RealTimeDmaScheduler::SetExecutableTiming( + const ExecutableReference* executable, const api::Timing& timing) { + VLOG(3) << "RealTimeDmaScheduler: received timing setting: " << timing.Dump(); + if (!executable) { + return util::InvalidArgumentError("Null executable refernce."); + } + + api::Timing candidate_timing = timing; + StdMutexLock lock(&mutex_); + auto timing_entry = inference_timings_.find(executable); + if (timing_entry != inference_timings_.end()) { + // If we have initial timing information, we could allow incremental + // updates on individual timing fields. + TimingInternal& existing_timing = timing_entry->second; + if (candidate_timing.fps < 0) { + candidate_timing.fps = existing_timing.fps; + } + if (candidate_timing.max_execution_time_ms < 0) { + candidate_timing.max_execution_time_ms = + existing_timing.max_execution_time_ms; + } + if (candidate_timing.tolerance_ms < 0) { + candidate_timing.tolerance_ms = existing_timing.tolerance_ms; + } + } else { + // Setting this executable's initial timing. + // FPS can be zero, which means a normal process w/o expected arrival rate. + // However, they cannot be negative. + if (candidate_timing.fps < 0 || + candidate_timing.max_execution_time_ms < 0 || + candidate_timing.tolerance_ms < 0) { + return util::InvalidArgumentError("Bad timing value(s)."); + } + } + + const TimingInternal timing_internal(candidate_timing); + + // Checks applicable to models with service guarantee requirements. + if (timing_internal.HasRealTimeRequirements()) { + ASSIGN_OR_RETURN(const int64 frame_time_us, + timing_internal.frame_time_us()); + // If FPS > 0, MET has to be > 0. + if (timing_internal.max_execution_time_ms == 0) { + return util::InvalidArgumentError(StringPrintf( + "Invalid max execution time: %dms.", timing.max_execution_time_ms)); + } + if (frame_time_us < timing_internal.max_execution_time_us()) { + return util::InvalidArgumentError(absl::StrFormat( + "Max execution time (%lldus) exceeds frame time (%lldus).", + timing_internal.max_execution_time_us(), frame_time_us)); + } + // Tolerance cannot be greater than frame time - MET (i.e. cannot push + // expected finish time beyond one frame). + if (timing_internal.tolerance_us() > + (frame_time_us - timing_internal.max_execution_time_us())) { + return util::InvalidArgumentError(absl::StrFormat( + "Invalid tolerance (%lldus). 
Needs to be less than %lldus to fit in " + "one frame.", + timing_internal.tolerance_us(), + frame_time_us - timing_internal.max_execution_time_us())); + } + } + + inference_timings_[executable] = timing_internal; + VLOG(3) << "RealTimeDmaScheduler: applied timing setting: " + << timing_internal.Dump(); + return util::OkStatus(); +} + +util::StatusOr RealTimeDmaScheduler::GetExecutableTiming( + const ExecutableReference* executable) const { + StdMutexLock lock(&mutex_); + + if (!executable) { + return util::InvalidArgumentError("Null executable refernce."); + } + + auto found = inference_timings_.find(executable); + if (found == inference_timings_.end()) { + return util::NotFoundError( + "Given executable reference has no associated timing information."); + } else { + return found->second; + } +} + +util::Status RealTimeDmaScheduler::NotifyRequestCompletion() { + // TODO: update MET dynamically. + return backing_scheduler_->NotifyRequestCompletion(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/real_time_dma_scheduler.h b/driver/real_time_dma_scheduler.h new file mode 100644 index 0000000..7d0ece2 --- /dev/null +++ b/driver/real_time_dma_scheduler.h @@ -0,0 +1,201 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_REAL_TIME_DMA_SCHEDULER_H_ +#define DARWINN_DRIVER_REAL_TIME_DMA_SCHEDULER_H_ + +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include + +#include "api/driver.h" +#include "api/package_reference.h" +#include "api/timing.h" +#include "api/watchdog.h" +#include "driver/dma_info.h" +#include "driver/dma_scheduler.h" +#include "driver/package_registry.h" +#include "driver/single_queue_dma_scheduler.h" +#include "driver/time_stamper/time_stamper.h" +#include "driver/tpu_request.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/std_mutex_lock.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Manages DMA with best-effort QoS. Works as a gating function to the +// underlying single queue DMA. +class RealTimeDmaScheduler : public DmaScheduler { + public: + RealTimeDmaScheduler() = delete; + RealTimeDmaScheduler(std::unique_ptr watchdog, + std::unique_ptr time_stamper) + : backing_scheduler_( + gtl::MakeUnique(std::move(watchdog))), + time_stamper_(std::move(time_stamper)) {} + ~RealTimeDmaScheduler() override = default; + + // Implements DmaScheduler interfaces. + util::Status Open() override; + util::Status Close(api::Driver::ClosingMode mode) override; + + // Submits a request. Forwards everything to the backing DMA scheduler (Single + // queue DMA) in normal mode; accepts the request if real-time constraint can + // be met, i.e. not impacting other service guarantees, otherwise rejects the + // request in real-time mode. For details, please refer to + // go/darwinn-qos-design. 
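Before Submit() can gate anything, timing has to be registered and real-time mode switched on. A rough sketch of that setup; the helper is hypothetical, and it assumes api::Timing exposes the fps, max_execution_time_ms and tolerance_ms fields referenced in this file, with the scheduler and executable reference coming from elsewhere in the driver:

```cpp
// Sketch only; ConfigureRealTime is hypothetical and the status-macro header
// location is an assumption based on other files in this tree.
#include "driver/real_time_dma_scheduler.h"
#include "port/status_macros.h"

namespace platforms {
namespace darwinn {
namespace driver {

util::Status ConfigureRealTime(RealTimeDmaScheduler* scheduler,
                               const ExecutableReference* exec_ref) {
  api::Timing timing;
  timing.fps = 30;                    // expected arrival rate
  timing.max_execution_time_ms = 10;  // MET per inference
  timing.tolerance_ms = 5;            // allowed slip within one frame
  RETURN_IF_ERROR(scheduler->SetExecutableTiming(exec_ref, timing));
  scheduler->SetRealtimeMode(true);   // from now on, Submit() gates requests
  return util::OkStatus();
}

}  // namespace driver
}  // namespace darwinn
}  // namespace platforms
```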
+ util::Status Submit(std::shared_ptr request) override + LOCKS_EXCLUDED(mutex_); + + // DmaScheduler interface. Simply forward them to the backing scheduler. + util::Status NotifyRequestCompletion() override; + util::Status CancelPendingRequests() override { + return backing_scheduler_->CancelPendingRequests(); + } + util::Status WaitActiveRequests() override { + return backing_scheduler_->WaitActiveRequests(); + } + // Implements lower level DMA routines. They should be directly forwarded to + // the backing driver. + util::StatusOr PeekNextDma() const override { + return backing_scheduler_->PeekNextDma(); + } + util::StatusOr GetNextDma() override { + return backing_scheduler_->GetNextDma(); + } + util::Status NotifyDmaCompletion(DmaInfo *dma_info) override { + return backing_scheduler_->NotifyDmaCompletion(dma_info); + } + bool IsEmpty() const override { + return backing_scheduler_->IsEmpty(); + } + int64 MaxRemainingCycles() const override { + return backing_scheduler_->MaxRemainingCycles(); + } + util::StatusOr> GetOldestActiveRequest() + const override { + return backing_scheduler_->GetOldestActiveRequest(); + } + + // Enters/leaves real-time mode. Note timing is preserved across toggling. + void SetRealtimeMode(bool on) LOCKS_EXCLUDED(mutex_); + + // Clears all timing information. + void ResetTiming() LOCKS_EXCLUDED(mutex_) { + StdMutexLock lock(&mutex_); + inference_timings_.clear(); + } + + // Sets expected arrival rates, max execution time and tolerance (in + // milliseconds) for an executable reference. + // -1 in any of the fields of api::Timing means keeping that individual value + // unchanged but updating the rest. + util::Status SetExecutableTiming(const ExecutableReference *executable, + const api::Timing &timing) + LOCKS_EXCLUDED(mutex_); + + // Removes timing information for a registered model. + util::Status RemoveExecutableTiming(const ExecutableReference *executable) + LOCKS_EXCLUDED(mutex_); + + // Returns the arrival rate and FPS of a given executable reference. + util::StatusOr GetExecutableTiming( + const ExecutableReference *executable) const LOCKS_EXCLUDED(mutex_); + + // Returns the arrival time of last request for a given executable reference. + int64 GetLastArrivalTime(const ExecutableReference *executable) const + LOCKS_EXCLUDED(mutex_); + + util::StatusOr GetExecutableFPS( + const ExecutableReference *executable) const { + ASSIGN_OR_RETURN(const auto timing, GetExecutableTiming(executable)); + return timing.fps; + } + + util::StatusOr GetExecutableMaxExecutionTimeMs( + const ExecutableReference *executable) const { + ASSIGN_OR_RETURN(const auto timing, GetExecutableTiming(executable)); + return timing.max_execution_time_ms; + } + + util::StatusOr GetExecutableToleranceMs( + const ExecutableReference *executable) const { + ASSIGN_OR_RETURN(const auto timing, GetExecutableTiming(executable)); + return timing.tolerance_ms; + } + + private: + // Tracks timing requirements and statistics for a registered executable. + // See go/darwinn-qos-design on the algorithm. + struct TimingInternal : public api::Timing { + TimingInternal() : api::Timing() {} + explicit TimingInternal(api::Timing timing) : api::Timing{timing} {} + + // Returns max execution time in microseconds. + int64 max_execution_time_us() const { return max_execution_time_ms * 1000; } + + // Returns tolerance in microseconds. + int64 tolerance_us() const { return tolerance_ms * 1000; } + + // Returns per frame time in microseconds, or error when FPS == 0. 
+ util::StatusOr frame_time_us() const { + if (fps == 0) { + return util::InvalidArgumentError( + "Can't calculate frame time of 0 FPS"); + } + return 1e6 / fps; + } + + // Returns if a timing configuration is real-time. + bool HasRealTimeRequirements() const { return fps > 0; } + + int64 last_arrival_time_us{0}; + int64 last_completion_time_us{0}; + }; + + // The underlying single queue DMA scheduler. Thread-safe by itself. + std::unique_ptr backing_scheduler_; + + // Time-stamper for tracking submission and completion time for requests. + std::unique_ptr time_stamper_; + + // Tracks registered executables. + std::unordered_map + inference_timings_ GUARDED_BY(mutex_); + + // Real-time mode? In non-real-time mode, this scheduler behaves the same as + // the underlying scheduler (single queue DMA). + bool real_time_mode_ GUARDED_BY(mutex_){false}; + + // Currently booked time by all scheduled inferences. + int64 time_booked_us_ GUARDED_BY(mutex_){0}; + + // Guards the inference timings. + mutable std::mutex mutex_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_SINGLE_QUEUE_DMA_SCHEDULER_H_ diff --git a/driver/registers/BUILD b/driver/registers/BUILD new file mode 100644 index 0000000..e2d6031 --- /dev/null +++ b/driver/registers/BUILD @@ -0,0 +1,39 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Register access related functionality. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "registers", + srcs = ["registers.cc"], + hdrs = ["registers.h"], + deps = ["//port"], +) + +cc_library( + name = "socket_registers", + srcs = ["socket_registers.cc"], + hdrs = ["socket_registers.h"], + deps = [ + ":registers", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) diff --git a/driver/registers/registers.cc b/driver/registers/registers.cc new file mode 100644 index 0000000..711e052 --- /dev/null +++ b/driver/registers/registers.cc @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
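registers.cc below implements Poll()/Poll32() on top of a spin-read helper; from a caller's point of view the interesting part is the timeout behaviour. A small sketch, assuming an already-opened concrete Registers implementation and a made-up CSR offset:

```cpp
// Sketch only; WaitForIdle and the offset are hypothetical.
#include "driver/registers/registers.h"

namespace platforms {
namespace darwinn {
namespace driver {

constexpr uint64 kHypotheticalStatusOffset = 0x48;

util::Status WaitForIdle(Registers* regs) {
  // Spin-reads the CSR until it reads back 1, or returns
  // DeadlineExceededError after 5 ms.
  return regs->Poll32(kHypotheticalStatusOffset, /*expected_value=*/1,
                      /*timeout_us=*/5000);
}

}  // namespace driver
}  // namespace darwinn
}  // namespace platforms
```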
+ +#include "driver/registers/registers.h" + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/status_macros.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status Registers::Poll(uint64 offset, uint64 expected_value, + int64 timeout_us) { + return SpinReadHelper(offset, expected_value, timeout_us, + [this](uint64 offset) { return Read(offset); }); +} + +util::Status Registers::Poll32(uint64 offset, uint32 expected_value, + int64 timeout_us) { + return SpinReadHelper(offset, expected_value, timeout_us, + [this](uint64 offset) { return Read32(offset); }); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/registers/registers.h b/driver/registers/registers.h new file mode 100644 index 0000000..6fbb2e6 --- /dev/null +++ b/driver/registers/registers.h @@ -0,0 +1,97 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_REGISTERS_REGISTERS_H_ +#define DARWINN_DRIVER_REGISTERS_REGISTERS_H_ + +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/time.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interface for CSR access. +class Registers { + public: + // To indicate the polling functions should poll forever. + static constexpr int64 kInfiniteTimeout = -1; + + virtual ~Registers() = default; + + // Open / Close the register interface. + virtual util::Status Open() = 0; + virtual util::Status Close() = 0; + + // Write / Read from a register at the given 64 bit aligned offset. + // Offset may be implementation dependent. + virtual util::Status Write(uint64 offset, uint64 value) = 0; + virtual util::StatusOr Read(uint64 offset) = 0; + + // Polls the specified register until it has the given value. + util::Status Poll(uint64 offset, uint64 expected_value) { + return Poll(offset, expected_value, kInfiniteTimeout); + } + + // Polls the sepcified register until it has the given value or until it + // takes longer than the provided timeout in microseconds. + // Polls forever if timeout is zero or negative. Same behavior as no timeout + // version. + virtual util::Status Poll(uint64 offset, uint64 expected_value, + int64 timeout_us); + + // 32-bit version of above. Usually, it is same as running 64 bit version. + virtual util::Status Write32(uint64 offset, uint32 value) = 0; + virtual util::StatusOr Read32(uint64 offset) = 0; + util::Status Poll32(uint64 offset, uint32 expected_value) { + return Poll32(offset, expected_value, kInfiniteTimeout); + } + virtual util::Status Poll32(uint64 offset, uint32 expected_value, + int64 timeout_us); + + protected: + // Helper function for spin reads a register until received expected value or + // reached timeout. Polls forever if timeout is zero or negative. + // Subclass to use this by passing the specific register read function. 
+ template + util::Status SpinReadHelper(uint64 offset, IntType expected_value, + int64 timeout_us, const FuncType& read_func) { + int64 start_time_us, end_time_us; + if (timeout_us > 0) { + start_time_us = GetCurrentTimeMicros(); + } + + ASSIGN_OR_RETURN(auto actual_value, read_func(offset)); + while (actual_value != expected_value) { + if (timeout_us > 0) { + end_time_us = GetCurrentTimeMicros(); + if (end_time_us - start_time_us > timeout_us) { + return util::DeadlineExceededError("Register poll timeout."); + } + } + ASSIGN_OR_RETURN(actual_value, read_func(offset)); + } + return util::OkStatus(); + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_REGISTERS_REGISTERS_H_ diff --git a/driver/registers/socket_registers.cc b/driver/registers/socket_registers.cc new file mode 100644 index 0000000..278435e --- /dev/null +++ b/driver/registers/socket_registers.cc @@ -0,0 +1,127 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/registers/socket_registers.h" + +#include +#include +#include +#include +#include + +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +SocketRegisters::SocketRegisters(const std::string& ip_address, int port) + : ip_address_(ip_address), port_(port) {} + +SocketRegisters::~SocketRegisters() { + if (socket_fd_ != -1) { + LOG(WARNING) + << "Destroying SocketRegisters - Close() has not yet been called!"; + util::Status status = Close(); + if (!status.ok()) { + LOG(ERROR) << status; + } + } +} + +util::Status SocketRegisters::Open() { + StdMutexLock lock(&mutex_); + if (socket_fd_ != -1) { + return util::FailedPreconditionError("Socket already open."); + } + + VLOG(1) << StringPrintf("Opening socket at %s:%d", ip_address_.c_str(), + port_); + + if ((socket_fd_ = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + return util::UnavailableError(StringPrintf("socket failed (%d).", errno)); + } + + // Clean up on error. + auto socket_closer = MakeCleanup( + [this]() EXCLUSIVE_LOCKS_REQUIRED(mutex_) { close(socket_fd_); }); + + // Setup server address. + sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port_); + + if (inet_pton(AF_INET, ip_address_.c_str(), &server_addr.sin_addr) <= 0) { + return util::FailedPreconditionError( + StringPrintf("Invalid ip address: %s", ip_address_.c_str())); + } + + // Make connection. 
+ if (connect(socket_fd_, reinterpret_cast(&server_addr), + sizeof(server_addr)) < 0) { + return util::UnavailableError(StringPrintf("connect failed (%d).", errno)); + } + + socket_closer.release(); + return util::Status(); // OK +} + +util::Status SocketRegisters::Close() { + StdMutexLock lock(&mutex_); + if (socket_fd_ == -1) { + return util::FailedPreconditionError("Socket already closed."); + } + + close(socket_fd_); + return util::Status(); // OK +} + +util::Status SocketRegisters::Write(uint64 offset, uint64 value) { + VLOG(2) << StringPrintf( + "Register write 0x%llx to 0x%llx", + static_cast(value), // NOLINT(runtime/int) + static_cast(offset)); // NOLINT(runtime/int) + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(Send('w')); + RETURN_IF_ERROR(Send(offset)); + RETURN_IF_ERROR(Send(value)); + return util::Status(); // OK +} + +util::StatusOr SocketRegisters::Read(uint64 offset) { + VLOG(2) << StringPrintf( + "Register read from 0x%llx", + static_cast(offset)); // NOLINT(runtime/int) + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(Send('r')); + RETURN_IF_ERROR(Send(offset)); + uint64 value; + if (recv(socket_fd_, &value, sizeof(value), MSG_WAITALL) < 0) { + return util::UnavailableError(StringPrintf("recv failed (%d).", errno)); + } + return value; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/registers/socket_registers.h b/driver/registers/socket_registers.h new file mode 100644 index 0000000..2ef537f --- /dev/null +++ b/driver/registers/socket_registers.h @@ -0,0 +1,91 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_REGISTERS_SOCKET_REGISTERS_H_ +#define DARWINN_DRIVER_REGISTERS_SOCKET_REGISTERS_H_ + +#include +#include +#include // NOLINT +#include + +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" +#include "port/stringprintf.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Socket implementation of the register interface that sends requests through +// socket and receives the results back through socket. +// +// Commands are sent as following: +// 1. 'r' or 'w' depending on read/write. +// 2. Offset for both read/write. +// 3. If write, value to write. +class SocketRegisters : public Registers { + public: + SocketRegisters(const std::string& ip_address, int port); + ~SocketRegisters() override; + + // This class is neither copyable nor movable. 
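For completeness, the other end of the framing described in the class comment above (a command byte, then a raw 8-byte offset, then for writes a raw 8-byte value, with reads answered by 8 bytes) might look roughly like the following. The server loop, its register map and its connection handling are all hypothetical; like the client's raw send()/recv(), it assumes host endianness on both ends:

```cpp
// Hypothetical register "server" sketch, assuming an already-accepted
// connection descriptor `fd`.
#include <sys/socket.h>

#include <cstdint>
#include <unordered_map>

void ServeRegisters(int fd) {
  std::unordered_map<uint64_t, uint64_t> regs;  // fake CSR space
  char op;
  while (recv(fd, &op, sizeof(op), MSG_WAITALL) == sizeof(op)) {
    uint64_t offset = 0;
    if (recv(fd, &offset, sizeof(offset), MSG_WAITALL) != sizeof(offset)) break;
    if (op == 'w') {                 // 'w': an 8-byte value follows
      uint64_t value = 0;
      if (recv(fd, &value, sizeof(value), MSG_WAITALL) != sizeof(value)) break;
      regs[offset] = value;
    } else if (op == 'r') {          // 'r': reply with the 8-byte value
      uint64_t value = regs[offset];
      if (send(fd, &value, sizeof(value), 0) < 0) break;
    }
  }
}
```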
+ SocketRegisters(const SocketRegisters&) = delete; + SocketRegisters& operator=(const SocketRegisters&) = delete; + + // Overrides from registers.h + util::Status Open() LOCKS_EXCLUDED(mutex_) override; + util::Status Close() LOCKS_EXCLUDED(mutex_) override; + util::Status Write(uint64 offset, uint64 value) + LOCKS_EXCLUDED(mutex_) override; + util::StatusOr Read(uint64 offset) LOCKS_EXCLUDED(mutex_) override; + util::Status Write32(uint64 offset, uint32 value) + LOCKS_EXCLUDED(mutex_) override { + return Write(offset, value); + } + util::StatusOr Read32(uint64 offset) LOCKS_EXCLUDED(mutex_) override { + return Read(offset); + } + + private: + template + util::Status Send(const T& message) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (send(socket_fd_, &message, sizeof(message), /*flags=*/0) < 0) { + return util::UnavailableError(StringPrintf("send failed (%d).", errno)); + } + return util::Status(); // OK + } + + // IP address. + const std::string ip_address_; + + // Port number. + const int port_; + + // Mutex that guards socket_fd_; + std::mutex mutex_; + + // Socket descriptor. + int socket_fd_ GUARDED_BY(mutex_){-1}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_REGISTERS_SOCKET_REGISTERS_H_ diff --git a/driver/request.cc b/driver/request.cc new file mode 100644 index 0000000..2c41b0b --- /dev/null +++ b/driver/request.cc @@ -0,0 +1,368 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
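The Request implementation that follows infers the caller's batch size from the supplied buffers and splits it into hardware-sized TpuRequests, padding the last one with no-op buffers. A standalone sketch of that arithmetic with illustrative numbers (MathUtil::CeilOfRatio from this tree is replaced by a local expression):

```cpp
// Standalone sketch of the batching arithmetic; values are illustrative.
#include <iostream>

int main() {
  const int request_batch_size = 5;   // buffers supplied by the caller
  const int hardware_batch_size = 2;  // native batch of the compiled model
  const int tpu_requests =            // CeilOfRatio(5, 2) == 3
      (request_batch_size + hardware_batch_size - 1) / hardware_batch_size;
  const int noop_buffers =            // padding added to the last TpuRequest
      tpu_requests * hardware_batch_size - request_batch_size;  // == 1
  std::cout << tpu_requests << " TPU requests, last one padded with "
            << noop_buffers << " no-op buffer(s)\n";
  return 0;
}
```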
+ +#include "driver/request.h" + +#include "api/request.h" +#include "port/math_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/time.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { +Request::Request(int id, const PackageReference& package_ref, + const TimeStamper& timestamper) + : id_(id), + package_ref_(package_ref), + main_executable_ref_(*package_ref.MainExecutableReference()), + hardware_batch_size_(package_ref.MainExecutableReference()->BatchSize()), + current_time_(timestamper) { + timing_.created_ns = timestamper.GetTimeNanoSeconds(); + timing_.submitted_ns = -1; + timing_.completed_ns = -1; + } + +util::Status Request::AddInput(const std::string& name, const Buffer& input) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kInitial)); + + RETURN_IF_ERROR(main_executable_ref_.ValidateInput(name, input)); + VLOG(3) << StringPrintf("Adding input \"%s\" with %zu bytes.", name.c_str(), + input.size_bytes()); + inputs_[name].push_back(input); + return util::OkStatus(); +} + +util::Status Request::AddOutput(const std::string& name, const Buffer output) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kInitial)); + + RETURN_IF_ERROR(main_executable_ref_.ValidateOutput(name, output)); + VLOG(3) << StringPrintf("Adding output \"%s\" with %zu bytes.", name.c_str(), + output.size_bytes()); + outputs_[name].push_back(output); + return util::OkStatus(); +} + +util::Status Request::SetPriority(int priority) { + if (priority < 0) { + return util::InvalidArgumentError(StringPrintf( + "Priority must be 0 or greater. %d was provided.", priority)); + } + + StdMutexLock lock(&mutex_); + priority_ = priority; + return util::OkStatus(); +} + +util::StatusOr Request::GetTiming() const { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kDone)); + return timing_; +} + +int Request::GetPriority() const { + StdMutexLock lock(&mutex_); + return priority_; +} + +util::Status Request::SetDone(Done done) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kInitial)); + + if (done_) { + return util::InvalidArgumentError("Done callback is already set."); + } + + done_ = std::move(done); + return util::OkStatus(); +} + +util::Status Request::Prepare() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kInitial)); + + if (!done_) { + return util::InvalidArgumentError("Done callback is not set."); + } + + // Batch size is inferred from the number of input and output buffers provided + // for each input and output layer. There are special cases where an + // executable may have no inputs and outputs (e.g. test executables) in which + // case we assume batch size 1. + if (main_executable_ref_.NumInputLayers() == 0 && + main_executable_ref_.NumOutputLayers() == 0) { + request_batch_size_ = 1; + required_tpu_request_count_ = 1; + pending_tpu_requests_ = 1; + return SetState(kPrepared); + } + + int batch_size = -1; + + for (const auto& name : main_executable_ref_.InputLayerNames()) { + if (inputs_.find(name) == inputs_.end()) { + return util::InvalidArgumentError( + StringPrintf("Unable to find input for layer %s.", name.c_str())); + } + + if (batch_size == -1) { + batch_size = inputs_[name].size(); + continue; + } + + if (inputs_[name].size() != batch_size) { + return util::InvalidArgumentError( + StringPrintf("Mismatched number of input buffers for \"%s\". 
" + "expected=%d, actual=%zu.", + name.c_str(), batch_size, inputs_[name].size())); + } + } + + for (const auto& name : main_executable_ref_.OutputLayerNames()) { + if (outputs_.find(name) == outputs_.end()) { + return util::InvalidArgumentError( + StringPrintf("Unable to find output for layer %s.", name.c_str())); + } + + if (batch_size == -1) { + batch_size = outputs_[name].size(); + continue; + } + + if (outputs_[name].size() != batch_size) { + return util::InvalidArgumentError( + StringPrintf("Mismatched number of output buffers for \"%s\". " + "expected=%d, actual=%zu.", + name.c_str(), batch_size, outputs_[name].size())); + } + } + + if (batch_size <= 0) { + return util::InvalidArgumentError("No input/output buffers found."); + } + + request_batch_size_ = batch_size; + required_tpu_request_count_ = + MathUtil::CeilOfRatio(request_batch_size_, hardware_batch_size_); + pending_tpu_requests_ = required_tpu_request_count_; + + VLOG(2) << StringPrintf( + "Request prepared, total batch size: %d, total TPU requests required: " + "%d.", + request_batch_size_, required_tpu_request_count_); + return SetState(kPrepared); +} + +util::StatusOr Request::RemainingTpuRequestCount() const { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kPrepared)); + return required_tpu_request_count_ - tpu_requests_prepared_; +} + +util::Status Request::PrepareTpuRequest( + std::shared_ptr tpu_request) { + TRACE_SCOPE("Request::PrepareTpuRequest"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kPrepared)); + + if (main_executable_ref_.NumInputLayers() == 0 && + main_executable_ref_.NumOutputLayers() == 0) { + return PrepareNoIORequest(tpu_request); + } else { + return PrepareIORequest(tpu_request); + } +} + +util::Status Request::PrepareNoIORequest( + std::shared_ptr tpu_request) { + TRACE_SCOPE("Request::PrepareNoIORequest"); + if (request_batch_size_ != 1) { + return util::InvalidArgumentError( + StringPrintf("Executable batch size is 1, yet %d sets of input/outputs " + "are provided.", + request_batch_size_)); + } + + if (tpu_requests_prepared_ >= 1) { + return util::FailedPreconditionError( + StringPrintf("%d are already prepared yet prepare was called again.", + tpu_requests_prepared_)); + } + + auto done = [this](int id, const util::Status& status) { + TpuRequestDone(id, status); + }; + RETURN_IF_ERROR(tpu_request->SetDone(std::move(done))); + + tpu_requests_prepared_ = 1; + return util::OkStatus(); +} + +util::Status Request::PrepareIORequest( + std::shared_ptr tpu_request) { + TRACE_SCOPE("Request::PrepareIORequest"); + if (tpu_requests_prepared_ >= required_tpu_request_count_) { + return util::InternalError( + StringPrintf("Software batch (expected size=%d, actual size=%d) " + "already saturated with prepared TPU requests", + required_tpu_request_count_, tpu_requests_prepared_)); + } + + for (int j = 0; j < hardware_batch_size_; ++j) { + const int buffer_index = tpu_requests_prepared_ * hardware_batch_size_ + j; + if (buffer_index >= request_batch_size_) { + CHECK_EQ(tpu_requests_prepared_ + 1, required_tpu_request_count_); + break; + } + + for (const auto& name : main_executable_ref_.InputLayerNames()) { + RETURN_IF_ERROR( + tpu_request->AddInput(name, inputs_.at(name)[buffer_index])); + } + + for (const auto& name : main_executable_ref_.OutputLayerNames()) { + RETURN_IF_ERROR( + tpu_request->AddOutput(name, outputs_.at(name)[buffer_index])); + } + } + + auto done = [this](int id, const util::Status& status) { + TpuRequestDone(id, status); + }; + 
RETURN_IF_ERROR(tpu_request->SetDone(std::move(done))); + + // In order not to confuse the TPU, if the last TpuRequest does not have + // enough input/outputs to support the entire native batch size, add dummy + // ones to break even. + if (tpu_requests_prepared_ + 1 == required_tpu_request_count_) { + const int num_noop_buffers = + (required_tpu_request_count_ * hardware_batch_size_) - + request_batch_size_; + if (num_noop_buffers > 0) { + for (const auto& name : main_executable_ref_.InputLayerNames()) { + RETURN_IF_ERROR(tpu_request->AddNoopInputs(name, num_noop_buffers)); + } + for (const auto& name : main_executable_ref_.OutputLayerNames()) { + RETURN_IF_ERROR(tpu_request->AddNoopOutputs(name, num_noop_buffers)); + } + } + } + + ++tpu_requests_prepared_; + return util::OkStatus(); +} + +void Request::NotifySubmission(TpuRequest::RequestType type) { + StdMutexLock lock(&mutex_); + auto time_now = current_time_.GetTimeNanoSeconds(); + if (timing_.submitted_ns == -1) { + timing_.submitted_ns = time_now; // Update parent submission time. + } + timing_.detail_timing.push_back(TimingEvent( + time_now, type, api::Request::TimingEvent::EventType::SUBMITTED)); +} + +void Request::NotifyCompletion(TpuRequest::RequestType type) { + StdMutexLock lock(&mutex_); + // Update parent completion time. + timing_.completed_ns = current_time_.GetTimeNanoSeconds(); + timing_.detail_timing.push_back( + TimingEvent(timing_.completed_ns, type, + api::Request::TimingEvent::EventType::COMPLETED)); +} + +void Request::TpuRequestDone(int id, const util::Status& status) { + // TODO Improve handling of this error. + CHECK_OK(HandleTpuRequestsDone(status, 1)); +} + +util::Status Request::HandleTpuRequestsDone(const util::Status& status, + int num_requests_done) { + Done done; + int64 request_id; + util::Status done_status; + + { + StdMutexLock lock(&mutex_); + // TODO Improve handling of this error. + RETURN_IF_ERROR(ValidateState(kPrepared)); + + if (num_requests_done > pending_tpu_requests_) { + return util::InternalError( + StringPrintf("Number of done requests (%d) exceeds number of pending " + "requests (%d).", + num_requests_done, pending_tpu_requests_)); + } + + pending_tpu_requests_ -= num_requests_done; + done_status_.Update(status); + if (pending_tpu_requests_ > 0) { + return util::OkStatus(); + } + + RETURN_IF_ERROR(SetState(kDone)); + + done = std::move(done_); + done_ = nullptr; + request_id = id_; + done_status = done_status_; + } + + done(request_id, done_status); + return util::OkStatus(); +} + +util::Status Request::SetState(State next_state) { + switch (state_) { + case kInitial: + if (next_state == kPrepared) { + state_ = next_state; + return util::OkStatus(); + } + break; + + case kPrepared: + if (next_state == kDone) { + state_ = next_state; + return util::OkStatus(); + } + break; + + case kDone: + return util::FailedPreconditionError( + StringPrintf("Cannot set state from done to %d.", next_state)); + } + + // Illegal state transition. + return util::FailedPreconditionError(StringPrintf( + "Invalid state transition. current=%d, next=%d.", state_, next_state)); +} + +util::Status Request::ValidateState(State state) const { + if (state_ != state) { + return util::FailedPreconditionError( + StringPrintf("Invalid state. 
Expected=%d, Actual=%d.", state, state_)); + } + return util::OkStatus(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/request.h b/driver/request.h new file mode 100644 index 0000000..ad6dadc --- /dev/null +++ b/driver/request.h @@ -0,0 +1,214 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_REQUEST_H_ +#define DARWINN_DRIVER_REQUEST_H_ + +#include "api/request.h" +#include "driver/time_stamper/time_stamper.h" +#include "driver/tpu_request.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// This class represents a top level inference request that is created by the +// runtime user. It may have an arbitrary batch size. Its responsibility is to +// sanity check the request and populate TPU requests that can be sent to the +// device as well as tracking their completion. +// +// This is a stateful class. Here's the execution pattern. +// 1. Construction (state: kInitial) +// 2. AddInput() and AddOutput() (can be multiple times) from API interface. +// (state: kInitial) +// 3. SetDone() in runtime (state: kInitial) +// 4. Prepare() in runtime (kInitial state changes to kPrepared). +// 5. PrepareTpuRequest() in runtime as many times as +// RequiredTpuRequestCount() (State: kPrepared). +// 6. Done callback is called when request finishes (state changes from +// // kPrepared to kDone). +class Request : public api::Request { + public: + // Constructs a request provided a unique ID and a reference to the package, + // and an interface to get current timestamps in nanoseconds. + Request(int id, const PackageReference& package_ref, + const TimeStamper& timestamper); + + // This class is not copyable nor movable. + Request(const Request&) = delete; + Request& operator=(const Request&) = delete; + + // Adds an input buffer. Please refer to the API documentation for more info. + util::Status AddInput(const std::string& name, const Buffer& input) override + LOCKS_EXCLUDED(mutex_); + + // Adds an output buffer. Please refer to the API documentation for more info. + util::Status AddOutput(const std::string& name, Buffer output) override + LOCKS_EXCLUDED(mutex_); + + util::Status SetPriority(int priority) override LOCKS_EXCLUDED(mutex_); + + // Returns the unique ID of this request. + int id() const override { return id_; } + + // Returns the timing information of this request. Please refer to the API + // documentation for more info. + util::StatusOr GetTiming() const override LOCKS_EXCLUDED(mutex_); + + // Returns a reference to the executable this request belongs to. + const ExecutableReference& MainExecutableReference() const { + return main_executable_ref_; + } + + const PackageReference& GetPackageReference() const { return package_ref_; } + + int GetPriority() const LOCKS_EXCLUDED(mutex_); + + // Sets the done callback function. This function is called the request has + // finished execution. 
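Taken together with the declarations above, a runtime-side caller walks the lifecycle spelled out in the class comment. A rough sketch follows; the helper is hypothetical, buffers and TpuRequests are assumed to come from elsewhere in the driver, and Done is assumed to be callable as (id, status), as the usage in request.cc suggests:

```cpp
// Sketch only; RunOnce, the layer names and the macro/header locations are
// assumptions, not part of this tree.
#include <memory>
#include <vector>

#include "driver/request.h"
#include "driver/tpu_request.h"
#include "port/status_macros.h"

namespace platforms {
namespace darwinn {
namespace driver {

util::Status RunOnce(Request* request, const Buffer& input, const Buffer& output,
                     const std::vector<std::shared_ptr<TpuRequest>>& tpu_requests) {
  RETURN_IF_ERROR(request->AddInput("input_layer", input));     // layer names are
  RETURN_IF_ERROR(request->AddOutput("output_layer", output));  // illustrative
  RETURN_IF_ERROR(request->SetDone(
      [](int id, const util::Status& status) { /* request finished */ }));
  RETURN_IF_ERROR(request->Prepare());
  ASSIGN_OR_RETURN(const int needed, request->RemainingTpuRequestCount());
  for (int i = 0; i < needed; ++i) {
    RETURN_IF_ERROR(request->PrepareTpuRequest(tpu_requests[i]));
    // Each prepared TpuRequest would then be handed to the DMA scheduler.
  }
  return util::OkStatus();
}

}  // namespace driver
}  // namespace darwinn
}  // namespace platforms
```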
+ util::Status SetDone(Done done) LOCKS_EXCLUDED(mutex_); + + // Prepares the request to be broken down to TPU requests. This should be + // called after we are through adding input/outputs, and have called the + // SetDone() function. + util::Status Prepare() LOCKS_EXCLUDED(mutex_); + + // Returns the number of TPU requests that are needed to be prepared and + // submitted for this request to be fully carried out. + util::StatusOr RemainingTpuRequestCount() const LOCKS_EXCLUDED(mutex_); + + // Sets the input/output buffers and callback of the provided TPU request + // based on the input/output buffers in this request. Can only be called after + // Prepare(). It needs to be called as many times as RequiredTpuRequestCount() + // to ensure that TPU requests for all batch elements are created. + util::Status PrepareTpuRequest(std::shared_ptr tpu_request) + LOCKS_EXCLUDED(mutex_); + + // Notifies the request that a part (or all) of it has been submitted to the + // hardware. + void NotifySubmission(TpuRequest::RequestType) LOCKS_EXCLUDED(mutex_); + + // Notifies the request that a part (or all) of it has completed execution on + // the hardware. + void NotifyCompletion(TpuRequest::RequestType) LOCKS_EXCLUDED(mutex_); + + // Number of estimated cycles it takes for a single TpuRequest of this request + // to take in order to run on TPU (only applies to execution requests, and not + // parameter caching). + int64 EstimatedCyclesPerInference() const { + return GetPackageReference().MainExecutableReference()->EstimatedCycles(); + } + + // Marks num_requests_done pending TpuRequests of this request as done with + // the provided status. It executes the done callback if all TPU requests are + // done at this point. + util::Status HandleTpuRequestsDone(const util::Status& status, + int num_requests_done) + LOCKS_EXCLUDED(mutex_); + + private: + // An enum to specify the state of a request. + enum State { + kInitial, // Input and outputs are still being added. + kPrepared, // Buffers are all added, done callback is set, and Prepare() + // function is complete. + kDone, // All TPU requests are finished. + }; + + // Sets the state of the request. Returns an error for an illegal transition. + util::Status SetState(State next_state) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Verifies that the current state is equal to the provided state. + util::Status ValidateState(State state) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Prepares a single TPU request for a request that has no input/outputs. + util::Status PrepareNoIORequest(std::shared_ptr tpu_request) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Sets the input/output buffers and callback of the provided TPU request + // based on the input/output buffers in this request. + util::Status PrepareIORequest(std::shared_ptr tpu_request) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Gets called on every TPU request callback. + void TpuRequestDone(int id, const util::Status& status) + LOCKS_EXCLUDED(mutex_); + + // The unique ID of this request. + const int id_; + + // A reference to the package this request is tied to. + const PackageReference& package_ref_; + + // The main executable reference this request needs to execute. + const ExecutableReference& main_executable_ref_; + + // Number of individual inferences that can be run in a single request to TPU. + // This is also referred to as data-parallelism. + const int hardware_batch_size_; + + // Maintains integrity of the request object. + mutable std::mutex mutex_; + + // Current state of the request. 
+ State state_ GUARDED_BY(mutex_) = kInitial; + + // The batch size of this request (no batching = 1). This field is valid only + // on kPrepared state and after. + int request_batch_size_ GUARDED_BY(mutex_); + + // Number of requests that runtime needs to make to TPU in order to process + // the entire request_batch_size_. This field is valid only on kPrepared state + // and after. + int required_tpu_request_count_ GUARDED_BY(mutex_); + + // All input buffers in this request (name->batch_index->buffer). + Buffer::NamedMap inputs_ GUARDED_BY(mutex_); + + // All output buffers in this request (name->batch_index->buffer). + Buffer::NamedMap outputs_ GUARDED_BY(mutex_); + + // Final request completion callback. + Done done_ GUARDED_BY(mutex_); + + // Number of tpu requests we are waiting for to finish. + int pending_tpu_requests_ GUARDED_BY(mutex_) = 0; + + // Stores the request done status. Each tpu_request done status updates this. + util::Status done_status_ GUARDED_BY(mutex_); + + // Gets the current time in nanoseconds. + const TimeStamper& current_time_; + + // Timing information of this request. + Timing timing_; + + // The scheduling priority of this request with respect to others. 0 is + // highest priority and the larger the number the lower the priority. Negative + // priorities are invalid. + int priority_ GUARDED_BY(mutex_) = 0; + + // Number of tpu requests that are already prepared. This field will max out + // on required_tpu_request_count_ and only after then the entire request will + // be completed. + int tpu_requests_prepared_ GUARDED_BY(mutex_) = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_REQUEST_H_ diff --git a/driver/run_controller.cc b/driver/run_controller.cc new file mode 100644 index 0000000..61cd31e --- /dev/null +++ b/driver/run_controller.cc @@ -0,0 +1,281 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/run_controller.h" + +#include "driver/config/common_csr_helper.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status_macros.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +RunController::RunController(const config::ChipConfig& config, + Registers* registers) + : scalar_core_csr_offsets_(config.GetScalarCoreCsrOffsets()), + tile_config_csr_offsets_(config.GetTileConfigCsrOffsets()), + tile_csr_offsets_(config.GetTileCsrOffsets()), + has_thread_csr_offsets_(config.HasThreadCsrOffsets()), + tile_thread_0_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread0CsrOffsets() : nullptr), + tile_thread_1_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread1CsrOffsets() : nullptr), + tile_thread_2_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread2CsrOffsets() : nullptr), + tile_thread_3_csr_offsets_(config.HasThreadCsrOffsets() ? 
+ &config.GetTileThread3CsrOffsets() : nullptr), + tile_thread_4_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread4CsrOffsets() : nullptr), + tile_thread_5_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread5CsrOffsets() : nullptr), + tile_thread_6_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread6CsrOffsets() : nullptr), + tile_thread_7_csr_offsets_(config.HasThreadCsrOffsets() ? + &config.GetTileThread7CsrOffsets() : nullptr), + registers_(registers) { + CHECK(registers != nullptr); +} + +util::Status RunController::DoRunControl(RunControl run_state) { + // Value of offset when register is not present in a project. + constexpr uint64 kInvalidOffset = static_cast(-1); + + const uint64 run_state_value = static_cast(run_state); + if (scalar_core_csr_offsets_.scalarCoreRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.scalarCoreRunControl, run_state_value)); + } else { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.scalarDatapath_0RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.avDataPopRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.avDataPopRunControl, run_state_value)); + } else { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.avDataPop_0RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.parameterPopRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.parameterPopRunControl, run_state_value)); + } else { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.parameterPop_0RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.infeedRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.infeedRunControl, run_state_value)); + } else { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.infeed_0_0RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.outfeedRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.outfeedRunControl, run_state_value)); + } else { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.outfeed_0_0RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.infeed1RunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.infeed1RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.infeed_0_1RunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.infeed_0_1RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.outfeed1RunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.outfeed1RunControl, run_state_value)); + } + if (scalar_core_csr_offsets_.outfeed_0_1RunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + scalar_core_csr_offsets_.outfeed_0_1RunControl, run_state_value)); + } + + // TODO: helper uses 7-bits as defined by CSR. Extract bitwidth + // automatically for different chips. + config::registers::TileConfig<7> helper; + helper.set_broadcast(); + RETURN_IF_ERROR( + registers_->Write(tile_config_csr_offsets_.tileconfig0, helper.raw())); + + // Wait until tileconfig0 is set correctly. Subsequent writes are going to + // tiles, but hardware does not guarantee correct ordering with previous + // write. 
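The write-then-poll-readback that follows is a general fencing trick: reading the broadcast config back guarantees the write has landed before any tile CSR write issued afterwards reaches the fabric. Restated generically (the helper, offsets and values are hypothetical; RETURN_IF_ERROR comes via registers.h):

```cpp
// Hypothetical restatement of the ordering pattern used in DoRunControl().
#include "driver/registers/registers.h"

namespace platforms {
namespace darwinn {
namespace driver {

util::Status BroadcastThenFence(Registers* regs, uint64 config_offset,
                                uint64 config_value, uint64 dependent_offset,
                                uint64 dependent_value) {
  RETURN_IF_ERROR(regs->Write(config_offset, config_value));
  // Polling the value back acts as a fence: the config write is known to
  // have completed before the dependent write below is issued.
  RETURN_IF_ERROR(regs->Poll(config_offset, config_value));
  return regs->Write(dependent_offset, dependent_value);
}

}  // namespace driver
}  // namespace darwinn
}  // namespace platforms
```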
+ // TODO + RETURN_IF_ERROR( + registers_->Poll(tile_config_csr_offsets_.tileconfig0, helper.raw())); + if (tile_csr_offsets_.opRunControl != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_0 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_0, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_1 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_1, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_2 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_2, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_3 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_3, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_4 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_4, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_5 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_5, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_6 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_6, run_state_value)); + } + if (tile_csr_offsets_.opRunControl_7 != kInvalidOffset) { + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.opRunControl_7, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write(tile_csr_offsets_.narrowToWideRunControl, + run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_0 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_0, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_1 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_1, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_2 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_2, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_3 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_3, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_4 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_4, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_5 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_5, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_6 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_6, run_state_value)); + } + if (tile_csr_offsets_.narrowToWideRunControl_7 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToWideRunControl_7, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write(tile_csr_offsets_.wideToNarrowRunControl, + run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_0 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_0, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_1 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + 
tile_csr_offsets_.wideToNarrowRunControl_1, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_2 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_2, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_3 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_3, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_4 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_4, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_5 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_5, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_6 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_6, run_state_value)); + } + if (tile_csr_offsets_.wideToNarrowRunControl_7 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.wideToNarrowRunControl_7, run_state_value)); + } + + const std::vector + tile_thread_csr_offsets = { + tile_thread_0_csr_offsets_, tile_thread_1_csr_offsets_, + tile_thread_2_csr_offsets_, tile_thread_3_csr_offsets_, + tile_thread_4_csr_offsets_, tile_thread_5_csr_offsets_, + tile_thread_6_csr_offsets_, tile_thread_7_csr_offsets_}; + if (has_thread_csr_offsets_) { + for (const auto* tile_thread_csr_offsets_ : tile_thread_csr_offsets) { + if (tile_thread_csr_offsets_->opRunControl_0 != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_thread_csr_offsets_->opRunControl_0, run_state_value)); + } + if (tile_thread_csr_offsets_->narrowToWideRunControl_0 != + kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_thread_csr_offsets_->narrowToWideRunControl_0, run_state_value)); + } + if (tile_thread_csr_offsets_->wideToNarrowRunControl_0 != + kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_thread_csr_offsets_->wideToNarrowRunControl_0, run_state_value)); + } + } + } + + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.meshBus0RunControl, run_state_value)); + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.meshBus1RunControl, run_state_value)); + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.meshBus2RunControl, run_state_value)); + RETURN_IF_ERROR( + registers_->Write(tile_csr_offsets_.meshBus3RunControl, run_state_value)); + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.ringBusConsumer0RunControl, run_state_value)); + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.ringBusConsumer1RunControl, run_state_value)); + RETURN_IF_ERROR(registers_->Write(tile_csr_offsets_.ringBusProducerRunControl, + run_state_value)); + if (tile_csr_offsets_.narrowToNarrowRunControl != kInvalidOffset) { + RETURN_IF_ERROR(registers_->Write( + tile_csr_offsets_.narrowToNarrowRunControl, run_state_value)); + } + + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/run_controller.h b/driver/run_controller.h new file mode 100644 index 0000000..0616bc9 --- /dev/null +++ b/driver/run_controller.h @@ -0,0 +1,66 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_RUN_CONTROLLER_H_ +#define DARWINN_DRIVER_RUN_CONTROLLER_H_ + +#include "driver/config/chip_config.h" +#include "driver/config/scalar_core_csr_offsets.h" +#include "driver/config/tile_config_csr_offsets.h" +#include "driver/config/tile_csr_offsets.h" +#include "driver/hardware_structures.h" +#include "driver/registers/registers.h" +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Controls run states of both scalar core and tiles. +class RunController { + public: + RunController(const config::ChipConfig& config, Registers* registers); + virtual ~RunController() = default; + + // This class is neither copyable nor movable. + RunController(const RunController&) = delete; + RunController& operator=(const RunController&) = delete; + + // Performs run control. + virtual util::Status DoRunControl(RunControl run_state); + + private: + // CSR offsets. + const config::ScalarCoreCsrOffsets& scalar_core_csr_offsets_; + const config::TileConfigCsrOffsets& tile_config_csr_offsets_; + const config::TileCsrOffsets& tile_csr_offsets_; + bool has_thread_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_0_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_1_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_2_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_3_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_4_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_5_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_6_csr_offsets_; + const config::TileThreadCsrOffsets* const tile_thread_7_csr_offsets_; + + // CSR interface. + Registers* const registers_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_RUN_CONTROLLER_H_ diff --git a/driver/scalar_core_controller.cc b/driver/scalar_core_controller.cc new file mode 100644 index 0000000..3612fab --- /dev/null +++ b/driver/scalar_core_controller.cc @@ -0,0 +1,123 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/scalar_core_controller.h" + +#include + +#include "driver/config/common_csr_helper.h" +#include "driver/registers/registers.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// TODO: This should eventually come from some configurations. 
+constexpr int kNumInterrupts = 4; + +} // namespace + +ScalarCoreController::ScalarCoreController(const config::ChipConfig& config, + Registers* registers) + : hib_user_csr_offsets_(config.GetHibUserCsrOffsets()), + registers_([registers]() { + CHECK(registers != nullptr); + return registers; + }()), + interrupt_controller_(config.GetScalarCoreInterruptCsrOffsets(), + registers, kNumInterrupts) { + interrupt_counts_.resize(kNumInterrupts, 0ULL); +} + +util::Status ScalarCoreController::Open() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/false)); + + // Sets |interrupt_counts_| to initial CSR values. + auto read_result = registers_->Read(hib_user_csr_offsets_.sc_host_int_count); + RETURN_IF_ERROR(read_result.status()); + + driver::config::registers::ScHostIntCount helper; + helper.set_raw(read_result.ValueOrDie()); + + for (int i = 0; i < kNumInterrupts; ++i) { + interrupt_counts_[i] = helper.get_field(i); + } + + open_ = true; + return util::Status(); // OK +} + +util::Status ScalarCoreController::Close() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/true)); + + open_ = false; + return util::Status(); // OK +} + +util::Status ScalarCoreController::EnableInterrupts() { + return interrupt_controller_.EnableInterrupts(); +} + +util::Status ScalarCoreController::DisableInterrupts() { + return interrupt_controller_.DisableInterrupts(); +} + +util::Status ScalarCoreController::ClearInterruptStatus(int id) { + return interrupt_controller_.ClearInterruptStatus(id); +} + +util::StatusOr ScalarCoreController::CheckInterruptCounts(int id) { + { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*open=*/true)); + } + + auto read_result = registers_->Read(hib_user_csr_offsets_.sc_host_int_count); + RETURN_IF_ERROR(read_result.status()); + + driver::config::registers::ScHostIntCount helper; + helper.set_raw(read_result.ValueOrDie()); + + const uint64 new_count = helper.get_field(id); + const uint64 current_count = interrupt_counts_[id]; + interrupt_counts_[id] = new_count; + + if (new_count >= current_count) { + return new_count - current_count; + } + + const uint64 max_counter = + helper.mask_field(id, std::numeric_limits::max()); + + return max_counter - current_count + 1 + new_count; +} + +util::Status ScalarCoreController::ValidateOpenState(bool open) const { + if (open_ != open) { + return util::FailedPreconditionError( + "Invalid state in ScalarCoreController."); + } + return util::Status(); // OK +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/scalar_core_controller.h b/driver/scalar_core_controller.h new file mode 100644 index 0000000..13ce3c9 --- /dev/null +++ b/driver/scalar_core_controller.h @@ -0,0 +1,87 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
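The CheckInterruptCounts() implementation above returns how many interrupts have fired since the previous call, including the case where the hardware count field wraps around its maximum. The standalone sketch below illustrates just that arithmetic; the 16-bit field width, kFieldMask, and CountDelta are assumptions made for illustration and are not part of this diff (the real field layout comes from the ScHostIntCount helper in common_csr_helper.h).

// Standalone sketch of the wrap-around delta computation used by
// ScalarCoreController::CheckInterruptCounts(). The field width is assumed.
#include <cstdint>
#include <iostream>

namespace {

// Assumed: each per-interrupt count occupies a 16-bit field of the CSR.
constexpr uint64_t kFieldMask = 0xFFFF;

// Returns how many interrupts fired between |previous| and |current|,
// tolerating a single wrap of the counter field.
uint64_t CountDelta(uint64_t previous, uint64_t current) {
  if (current >= previous) return current - previous;
  // Wrapped: steps up to the field maximum, one step for the wrap to zero,
  // then the new value.
  return (kFieldMask - previous) + 1 + current;
}

}  // namespace

int main() {
  std::cout << CountDelta(10, 15) << "\n";     // 5
  std::cout << CountDelta(0xFFFE, 3) << "\n";  // 5, counter wrapped
  return 0;
}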
+ +#ifndef DARWINN_DRIVER_SCALAR_CORE_CONTROLLER_H_ +#define DARWINN_DRIVER_SCALAR_CORE_CONTROLLER_H_ + +#include // NOLINT +#include + +#include "driver/config/chip_config.h" +#include "driver/config/hib_user_csr_offsets.h" +#include "driver/interrupt/interrupt_controller.h" +#include "driver/registers/registers.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Controls scalar core. +class ScalarCoreController { + public: + ScalarCoreController(const config::ChipConfig& config, Registers* registers); + + // This class is neither copyable nor movable. + ScalarCoreController(const ScalarCoreController&) = delete; + ScalarCoreController& operator=(const ScalarCoreController&) = delete; + + virtual ~ScalarCoreController() = default; + + // Opens/closes the controller. + virtual util::Status Open() LOCKS_EXCLUDED(mutex_); + virtual util::Status Close() LOCKS_EXCLUDED(mutex_); + + // Enable/disables interrupts. + util::Status EnableInterrupts() LOCKS_EXCLUDED(mutex_); + util::Status DisableInterrupts() LOCKS_EXCLUDED(mutex_); + + // Clears interrupt status register to notify that host has received the + // interrupt. + util::Status ClearInterruptStatus(int id) LOCKS_EXCLUDED(mutex_); + + // Reads and returns scalar core interrupt count register for given |id|. Read + // is destructive in the sense that the second read to the same |id| will + // return 0 assuming that there was no change in the CSR from the hardware + // side. + virtual util::StatusOr CheckInterruptCounts(int id) + LOCKS_EXCLUDED(mutex_); + + private: + // Returns an error if not |open|. + util::Status ValidateOpenState(bool open) const SHARED_LOCKS_REQUIRED(mutex_); + + // CSR offsets. + const config::HibUserCsrOffsets& hib_user_csr_offsets_; + + // CSR interface. + Registers* const registers_; + + // Interrupt controller. + InterruptController interrupt_controller_; + + // Counted interrupts from scalar core. + std::vector interrupt_counts_; + + // Guard/track the open status of ScalarCoreController. + std::mutex mutex_; + bool open_ GUARDED_BY(mutex_) {false}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_SCALAR_CORE_CONTROLLER_H_ diff --git a/driver/single_queue_dma_scheduler.cc b/driver/single_queue_dma_scheduler.cc new file mode 100644 index 0000000..cc37736 --- /dev/null +++ b/driver/single_queue_dma_scheduler.cc @@ -0,0 +1,418 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
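ScalarCoreController above, like most classes introduced in this change, tracks an open/closed flag under a mutex and routes every entry point through a ValidateOpenState() helper, so calls made in the wrong state fail with a precondition error rather than touching hardware. The snippet below is a self-contained sketch of that pattern only; OpenCloseExample is a made-up name, the thread-safety annotations are replaced with no-op macros, and errors are thrown instead of returned as util::Status.

// Minimal sketch of the mutex-guarded open/close pattern, with stand-in
// no-op macros in place of the annotations from port/thread_annotations.h.
#include <mutex>
#include <stdexcept>

#define GUARDED_BY(x)
#define LOCKS_EXCLUDED(x)

class OpenCloseExample {
 public:
  void Open() LOCKS_EXCLUDED(mutex_) {
    std::lock_guard<std::mutex> lock(mutex_);
    ValidateOpenState(/*open=*/false);
    open_ = true;
  }

  void Close() LOCKS_EXCLUDED(mutex_) {
    std::lock_guard<std::mutex> lock(mutex_);
    ValidateOpenState(/*open=*/true);
    open_ = false;
  }

 private:
  // The driver returns util::FailedPreconditionError here; this sketch throws.
  void ValidateOpenState(bool open) const {
    if (open_ != open) throw std::logic_error("Invalid state.");
  }

  mutable std::mutex mutex_;
  bool open_ GUARDED_BY(mutex_) = false;
};

int main() {
  OpenCloseExample controller;
  controller.Open();   // ok: was closed
  controller.Close();  // ok: was open
  return 0;
}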
+ +#include "driver/single_queue_dma_scheduler.h" + +#include +#include + +#include "api/driver.h" +#include "api/watchdog.h" +#include "driver/tpu_request.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status SingleQueueDmaScheduler::ValidateOpenState(bool is_open) const { + if (is_open_ != is_open) { + return util::FailedPreconditionError( + StringPrintf("Bad state: expected=%d, actual=%d", is_open, is_open_)); + } + return util::Status(); // OK +} + +util::Status SingleQueueDmaScheduler::Open() { + StdMutexLock lock(&mutex_); + if (!IsEmptyLocked()) { + return util::FailedPreconditionError("DMA queues are not empty"); + } + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/false)); + is_open_ = true; + RETURN_IF_ERROR(watchdog_->Deactivate()); + return util::Status(); // OK +} + +util::Status SingleQueueDmaScheduler::Close(api::Driver::ClosingMode mode) { + { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + while (!pending_dmas_.empty()) { + pending_dmas_.pop(); + } + } + + util::Status status; + status.Update(CancelPendingRequests()); + if (mode == api::Driver::ClosingMode::kAsap) { + status.Update(CancelActiveRequests()); + } else { + status.Update(CloseActiveDmas()); + } + + StdMutexLock lock(&mutex_); + is_open_ = false; + return status; +} + +util::Status SingleQueueDmaScheduler::Submit( + std::shared_ptr request) { + TRACE_SCOPE("SingleQueueDmaScheduler::Submit"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + + RETURN_IF_ERROR(request->NotifyRequestSubmitted()); + VLOG(3) << StringPrintf("Request[%d]: Submitted", request->id()); + ASSIGN_OR_RETURN(auto dmas, request->GetDmaInfos()); + pending_tasks_.push_back({std::move(request), std::move(dmas)}); + + return util::Status(); // OK +} + +util::StatusOr SingleQueueDmaScheduler::PeekNextDma() const { + TRACE_SCOPE("SingleQueueDmaScheduler::PeekNextDma"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + if (pending_dmas_.empty() && pending_tasks_.empty()) { + return DmaDescriptorType::kLocalFence; + } + + if (pending_dmas_.empty()) { + return pending_tasks_.front().dmas.front().type(); + } else { + return pending_dmas_.front().info->type(); + } +} + +util::StatusOr SingleQueueDmaScheduler::GetNextDma() { + TRACE_SCOPE("SingleQueueDmaScheduler::GetNextDma"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + if (pending_dmas_.empty() && pending_tasks_.empty()) { + return nullptr; + } + + if (pending_dmas_.empty()) { + auto& task = pending_tasks_.front(); + RETURN_IF_ERROR(task.request->NotifyRequestActive()); + TpuRequest* request = task.request.get(); + for (auto& dma : task.dmas) { + pending_dmas_.push({&dma, request}); + } + active_tasks_.push_back(std::move(task)); + pending_tasks_.pop_front(); + RETURN_IF_ERROR(watchdog_->Activate().status()); + } + + // If fenced, return empty DMAs. 
+  const auto& pending_front = pending_dmas_.front();
+  if (pending_front.info->type() == DmaDescriptorType::kLocalFence ||
+      pending_front.info->type() == DmaDescriptorType::kGlobalFence) {
+    return nullptr;
+  }
+
+  pending_front.info->MarkActive();
+  VLOG(7) << StringPrintf("Request[%d]: Scheduling DMA[%d]",
+                          pending_front.request->id(),
+                          pending_front.info->id());
+
+  auto* next_dma = pending_front.info;
+  pending_dmas_.pop();
+  return next_dma;
+}
+
+util::Status SingleQueueDmaScheduler::NotifyDmaCompletion(DmaInfo* dma_info) {
+  TRACE_SCOPE("SingleQueueDmaScheduler::NotifyDmaCompletion");
+  if (!dma_info->IsActive()) {
+    const auto dma_dump = dma_info->Dump();
+    return util::FailedPreconditionError(
+        StringPrintf("Cannot complete inactive DMA: %s", dma_dump.c_str()));
+  }
+
+  {
+    StdMutexLock lock(&mutex_);
+    RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true));
+
+    dma_info->MarkCompleted();
+    VLOG(7) << StringPrintf("Completing DMA[%d]", dma_info->id());
+  }
+
+  RETURN_IF_ERROR(HandleCompletedTasks());
+
+  StdMutexLock lock(&mutex_);
+  wait_active_dmas_complete_.notify_all();
+  if (pending_dmas_.empty()) {
+    return util::Status();  // OK
+  }
+
+  const auto& pending_front = pending_dmas_.front();
+  if (pending_front.info->type() != DmaDescriptorType::kLocalFence) {
+    return util::Status();  // OK
+  }
+
+  // Clear the local fence if it has completed.
+  RETURN_IF_ERROR(HandleActiveTasks());
+  if (pending_front.info->IsCompleted()) {
+    VLOG(7) << StringPrintf("Request[%d]: Local fence done",
+                            pending_front.request->id());
+    pending_dmas_.pop();
+  }
+  return util::Status();  // OK
+}
+
+util::Status SingleQueueDmaScheduler::NotifyRequestCompletion() {
+  TRACE_SCOPE("SingleQueueDmaScheduler::NotifyRequestCompletion");
+
+  // This region holds the lock because it manipulates the task and DMA
+  // queues. We may also need to call NotifyCompletion() on the completed
+  // request, which does not require the lock; in that case the request is
+  // moved to |request_to_be_notified| and notified after the lock is
+  // released.
+  std::shared_ptr<TpuRequest> request_to_be_notified;
+  {
+    StdMutexLock lock(&mutex_);
+
+    RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true));
+    if (active_tasks_.empty()) {
+      return util::FailedPreconditionError("No active request to complete");
+    }
+
+    // Requests are always handled in FIFO order.
+    TpuRequest* completed_request = active_tasks_.front().request.get();
+    if (!pending_dmas_.empty()) {
+      const auto& pending_front = pending_dmas_.front();
+
+      if (pending_front.request == completed_request) {
+        // Clear the global fence if one exists.
+ if (pending_front.info->type() == DmaDescriptorType::kGlobalFence) { + VLOG(7) << StringPrintf("Request[%d]: Global fence done", + completed_request->id()); + pending_front.info->MarkCompleted(); + pending_dmas_.pop(); + } else { + return util::FailedPreconditionError( + StringPrintf("Request[%d] is completing while DMAs are pending.", + completed_request->id())); + } + } + } + + RETURN_IF_ERROR(HandleActiveTasks()); + Task next_front = std::move(active_tasks_.front()); + active_tasks_.pop_front(); + + RETURN_IF_ERROR(watchdog_->Signal()); + if (active_tasks_.empty()) { + RETURN_IF_ERROR(watchdog_->Deactivate()); + } + + if (next_front.dmas.empty() && completed_tasks_.empty()) { + request_to_be_notified = std::move(next_front.request); + } else { + completed_tasks_.push_back(std::move(next_front)); + } + } + + if (request_to_be_notified) { + RETURN_IF_ERROR(request_to_be_notified->NotifyCompletion(util::OkStatus())); + VLOG(3) << StringPrintf("Request[%d]: Completed", + request_to_be_notified->id()); + wait_active_requests_complete_.notify_all(); + } + + return util::OkStatus(); +} + +util::Status SingleQueueDmaScheduler::CancelPendingRequests() { + util::Status status; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + status.Update(CancelTaskQueue(pending_tasks_)); + return status; +} + +util::Status SingleQueueDmaScheduler::WaitActiveRequests() { + TRACE_SCOPE("SingleQueueDmaScheduler::WaitActiveRequests"); + StdCondMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + while (!completed_tasks_.empty() || !active_tasks_.empty()) { + VLOG(3) << StringPrintf("Waiting for %zd more active requests", + completed_tasks_.size() + active_tasks_.size()); + wait_active_requests_complete_.wait(lock); + } + return util::Status(); // OK +} + +int64 SingleQueueDmaScheduler::MaxRemainingCycles() const { + StdMutexLock lock(&mutex_); + int64 cycles = 0; + for (const auto& task : pending_tasks_) { + cycles += task.request->executable_reference().EstimatedCycles(); + } + for (const auto& task : active_tasks_) { + cycles += task.request->executable_reference().EstimatedCycles(); + } + return cycles; +} + +util::Status SingleQueueDmaScheduler::HandleCompletedTasks() { + TRACE_SCOPE("SingleQueueDmaScheduler::HandleCompletedTasks"); + + std::vector> completed_requests; + bool notify = false; + + // We need to lock the mutex for doing queue operations but running request + // callbacks don't need that lock. Therefore we push the completed requests + // into a vector and run their NotifyCompletion methods outside of the lock + // zone. + { + StdMutexLock lock(&mutex_); + + if (completed_tasks_.empty()) { + return util::OkStatus(); + } + + completed_tasks_.front().dmas.remove_if( + [](const DmaInfo& dma_info) { return dma_info.IsCompleted(); }); + + // Complete tasks, whose DMAs are all completed. 
+ while (completed_tasks_.front().dmas.empty()) { + auto& front_task = completed_tasks_.front(); + VLOG(3) << StringPrintf("Request[%d]: Completed", + front_task.request->id()); + completed_requests.push_back(std::move(front_task.request)); + completed_tasks_.pop_front(); + + if (completed_tasks_.empty()) { + notify = true; + break; + } + + completed_tasks_.front().dmas.remove_if( + [](const DmaInfo& dma_info) { return dma_info.IsCompleted(); }); + } + } + + for (auto& request : completed_requests) { + RETURN_IF_ERROR(request->NotifyCompletion(util::OkStatus())); + } + + if (notify) { + wait_active_requests_complete_.notify_all(); + } + + return util::OkStatus(); +} + +util::Status SingleQueueDmaScheduler::HandleActiveTasks() { + TRACE_SCOPE("SingleQueueDmaScheduler::HandleActiveTasks"); + if (active_tasks_.empty()) { + return util::Status(); // OK + } + + auto& front_task = active_tasks_.front(); + front_task.dmas.remove_if( + [](const DmaInfo& dma_info) { return dma_info.IsCompleted(); }); + + if (front_task.dmas.empty()) { + return util::Status(); // OK + } + + auto& front_dma = front_task.dmas.front(); + // If first remaining DMA is local fence, mark it completed. + if (front_dma.type() == DmaDescriptorType::kLocalFence) { + front_dma.MarkCompleted(); + } + return util::Status(); // OK +} + +util::Status SingleQueueDmaScheduler::CloseActiveDmas() { + TRACE_SCOPE("SingleQueueDmaScheduler::CloseActiveDmas"); + StdCondMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + while (!completed_tasks_.empty()) { + completed_tasks_.front().dmas.remove_if( + [](const DmaInfo& info) { return !info.IsActive(); }); + if (completed_tasks_.front().dmas.empty()) { + completed_tasks_.pop_front(); + } + + if (completed_tasks_.empty()) { + break; + } + wait_active_dmas_complete_.wait(lock); + } + while (!active_tasks_.empty()) { + active_tasks_.front().dmas.remove_if( + [](const DmaInfo& info) { return !info.IsActive(); }); + if (active_tasks_.front().dmas.empty()) { + active_tasks_.pop_front(); + RETURN_IF_ERROR(watchdog_->Signal()); + } + + if (active_tasks_.empty()) { + RETURN_IF_ERROR(watchdog_->Deactivate()); + break; + } + + wait_active_dmas_complete_.wait(lock); + } + return util::Status(); // OK +} + +util::Status SingleQueueDmaScheduler::CancelActiveRequests() { + util::Status status; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateOpenState(/*is_open=*/true)); + + status.Update(CancelTaskQueue(active_tasks_)); + status.Update(CancelTaskQueue(completed_tasks_)); + while (!pending_dmas_.empty()) { + pending_dmas_.pop(); + } + + RETURN_IF_ERROR(watchdog_->Deactivate()); + + return status; +} + +util::Status SingleQueueDmaScheduler::CancelTaskQueue(std::deque& tasks) { + util::Status status; + while (!tasks.empty()) { + status.Update(tasks.front().request->Cancel()); + tasks.pop_front(); + } + return status; +} + +util::StatusOr> +SingleQueueDmaScheduler::GetOldestActiveRequest() const { + StdMutexLock lock(&mutex_); + if (active_tasks_.empty()) { + return util::UnknownError( + "No requests active when querying for oldest active request."); + } + + return active_tasks_.front().GetTpuRequest(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/single_queue_dma_scheduler.h b/driver/single_queue_dma_scheduler.h new file mode 100644 index 0000000..230f4da --- /dev/null +++ b/driver/single_queue_dma_scheduler.h @@ -0,0 +1,165 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_SINGLE_QUEUE_DMA_SCHEDULER_H_ +#define DARWINN_DRIVER_SINGLE_QUEUE_DMA_SCHEDULER_H_ + +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include + +#include "api/driver.h" +#include "api/watchdog.h" +#include "driver/dma_info.h" +#include "driver/dma_scheduler.h" +#include "driver/tpu_request.h" +#include "port/status.h" +#include "port/std_mutex_lock.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Manages the processing order of DMAs with single queue. All DMAs are +// serialized. Thread-safe. +class SingleQueueDmaScheduler : public DmaScheduler { + public: + SingleQueueDmaScheduler(std::unique_ptr watchdog) + : watchdog_(std::move(watchdog)) {} + ~SingleQueueDmaScheduler() override = default; + + // Implements DmaScheduler interfaces. + util::Status Open() override LOCKS_EXCLUDED(mutex_); + util::Status Close(api::Driver::ClosingMode mode) override + LOCKS_EXCLUDED(mutex_); + util::Status Submit(std::shared_ptr request) override + LOCKS_EXCLUDED(mutex_); + util::StatusOr PeekNextDma() const override + LOCKS_EXCLUDED(mutex_); + util::StatusOr GetNextDma() override LOCKS_EXCLUDED(mutex_); + util::Status NotifyDmaCompletion(DmaInfo* dma_info) override + LOCKS_EXCLUDED(mutex_); + util::Status NotifyRequestCompletion() override LOCKS_EXCLUDED(mutex_); + util::Status CancelPendingRequests() override LOCKS_EXCLUDED(mutex_); + util::Status WaitActiveRequests() override LOCKS_EXCLUDED(mutex_); + bool IsEmpty() const override LOCKS_EXCLUDED(mutex_) { + StdMutexLock lock(&mutex_); + return IsEmptyLocked(); + } + int64 MaxRemainingCycles() const override LOCKS_EXCLUDED(mutex_); + util::StatusOr> GetOldestActiveRequest() + const override LOCKS_EXCLUDED(mutex_); + + private: + // A data structure for managing Request and associated DMAs. + struct Task { + Task(std::shared_ptr request, std::list&& dmas) + : request(std::move(request)), dmas(std::move(dmas)) {} + + // This type is movable. + Task(Task&& other) + : request(std::move(other.request)), dmas(std::move(other.dmas)) {} + Task& operator=(Task&& other) { + if (this != &other) { + request = std::move(other.request); + dmas = std::move(other.dmas); + } + return *this; + } + + // Returns the associated TpuRequest. + std::shared_ptr GetTpuRequest() const { + return request; + } + + // Request. + std::shared_ptr request; + + // DMAs to be performed to serve request. std::list is intentionally used to + // have valid pointers while other members removed. + std::list dmas; + }; + + // A data structure for keeping track of DMA and its associated request. + struct PendingDma { + // DMA. + DmaInfo* info; + + // Related request. + TpuRequest* request; + }; + + // Validates whether in "is_open" state. + util::Status ValidateOpenState(bool is_open) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Locked version of IsEmpty(). 
+ bool IsEmptyLocked() const EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + return pending_tasks_.empty() && active_tasks_.empty() && + pending_dmas_.empty(); + } + + // Handles all completed DMAs related cleanups for given tasks. + util::Status HandleCompletedTasks() LOCKS_EXCLUDED(mutex_); + util::Status HandleActiveTasks() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Waits until all active DMAs are completed. + util::Status CloseActiveDmas() LOCKS_EXCLUDED(mutex_); + + // Cancels all the active DMAs and requests. + util::Status CancelActiveRequests() LOCKS_EXCLUDED(mutex_); + + // Cancels all the tasks in a provided queue. + util::Status CancelTaskQueue(std::deque& tasks) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Guards all the related queues. + mutable std::mutex mutex_; + + // A notification to wait for all active requests to complete. + std::condition_variable wait_active_requests_complete_; + + // A notification to wait for all active dmas to complete. + std::condition_variable wait_active_dmas_complete_; + + // Tracks open state. + bool is_open_ GUARDED_BY(mutex_){false}; + + // Pending tasks that have not yet performed any DMAs to DarwiNN device. + std::deque pending_tasks_ GUARDED_BY(mutex_); + + // Active tasks that have delivered DMAs fully or partially to DarwiNN device. + std::deque active_tasks_ GUARDED_BY(mutex_); + + // Completed tasks that may have few active on-going DMAs. + std::deque completed_tasks_ GUARDED_BY(mutex_); + + // DMAs belonging to active requests that are not yet served. + std::queue pending_dmas_ GUARDED_BY(mutex_); + + // A watchdog passed down from the driver to keep track of TPU being active. + // DmaScheduler is responsible for activating the watchdog whenever a task + // enters active queue and de-activating it when the queue is empty. + std::unique_ptr watchdog_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_SINGLE_QUEUE_DMA_SCHEDULER_H_ diff --git a/driver/single_tpu_request.cc b/driver/single_tpu_request.cc new file mode 100644 index 0000000..305e11b --- /dev/null +++ b/driver/single_tpu_request.cc @@ -0,0 +1,660 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
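The Task struct above stores its DmaInfo objects in a std::list precisely because list elements never move in memory, so the raw pointers the scheduler hands out (PendingDma::info, or the DmaInfo* returned by GetNextDma()) stay valid while other elements are erased or appended. The short demo below illustrates that property in isolation; DmaLike is a made-up stand-in type and is not part of the driver.

// Standalone demo: std::list keeps element addresses stable, unlike a vector
// which may reallocate and invalidate outstanding pointers.
#include <iostream>
#include <iterator>
#include <list>

struct DmaLike {
  int id;
};

int main() {
  std::list<DmaLike> dmas = {{0}, {1}, {2}};
  DmaLike* second = &*std::next(dmas.begin());  // raw pointer into the list

  dmas.pop_front();     // remove another element
  dmas.push_back({3});  // append a new one

  std::cout << second->id << "\n";  // still valid; prints 1
  return 0;
}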
+ +#include "driver/single_tpu_request.h" + +#include + +#include "api/allocated_buffer.h" +#include "api/buffer.h" +#include "driver/allocator.h" +#include "driver/executable_util.h" +#include "driver/hardware_structures.h" +#include "driver/instruction_buffers.h" +#include "driver/memory/address_space.h" +#include "driver/package_registry.h" +#include "driver/request.h" +#include "executable/executable_generated.h" +#include "port/array_slice.h" +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/macros.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +using ::flatbuffers::VectorLength; + +SingleTpuRequest::SingleTpuRequest( + int id, const std::shared_ptr parent_request, + const ExecutableReference* executable_reference, Allocator* allocator, + DramAllocator* dram_allocator, + std::unique_ptr device_buffer_mapper, + const DmaInfoExtractor* extractor, uint64 alignment_bytes, Done done, + RequestType type) + : id_(id), + type_(type), + parent_request_(parent_request), + executable_reference_(*[executable_reference]() { + CHECK(executable_reference != nullptr); + return executable_reference; + }()), + allocator_([allocator]() { + CHECK(allocator != nullptr); + return allocator; + }()), + dram_allocator_([dram_allocator]() { + CHECK(dram_allocator != nullptr); + return dram_allocator; + }()), + device_buffer_mapper_(std::move(device_buffer_mapper)), + extractor_(*[extractor]() { + CHECK(extractor != nullptr); + return extractor; + }()), + done_(std::move(done)), + parameter_device_buffer_( + executable_reference_.GetParameterDeviceBuffer()), + alignment_bytes_(alignment_bytes) { + VLOG(5) << StringPrintf("[%d] Request constructed.", id_); +} + +SingleTpuRequest::SingleTpuRequest( + int id, const std::shared_ptr parent_request, + const ExecutableReference* executable_reference, Allocator* allocator, + DramAllocator* dram_allocator, + std::unique_ptr device_buffer_mapper, + const DmaInfoExtractor* extractor, uint64 alignment_bytes, + RequestType type) + : SingleTpuRequest(id, parent_request, executable_reference, allocator, + dram_allocator, std::move(device_buffer_mapper), + extractor, alignment_bytes, + /*done=*/nullptr, type) {} + +SingleTpuRequest::~SingleTpuRequest() { + VLOG(5) << StringPrintf("[%d] Request destroyed.", id_); + CHECK_OK(Cleanup()); +} + +util::Status SingleTpuRequest::SetDone(Done done) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + done_ = std::move(done); + return util::OkStatus(); +} + +util::Status SingleTpuRequest::AddInput(const std::string& name, + const Buffer& user_input) { + TRACE_SCOPE("SingleTpuRequest::AddInput"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + RETURN_IF_ERROR(executable_reference_.ValidateInput(name, user_input)); + VLOG(3) << StringPrintf("Adding input \"%s\" with %zu bytes.", name.c_str(), + user_input.size_bytes()); + + ASSIGN_OR_RETURN(const auto* layer, executable_reference_.InputLayer(name)); + Buffer host_input = user_input; + + // For iterative models, we need to add padding after each iteration. 
+ if (layer->execution_count_per_inference() > 1 && + host_input.size_bytes() != layer->PaddedSizeBytes()) { + if (user_input.IsDramType()) + return util::UnimplementedError( + "DRAM input buffers currently do not support " + "execution_count_per_inference > 1"); + host_input = ScatterInput(user_input, layer); + } + + if (layer->SignedDataType()) { + if (user_input.IsDramType()) + return util::UnimplementedError( + "DRAM input buffers currently do not support " + "signed data type"); + RETURN_IF_ERROR(layer->TransformSignedDataType(host_input)); + } + + // If this buffer needs to be cached on TPU DRAM, we should replace it with a + // DRAM buffer and copy the contents. If DRAM buffer allocation fails, we will + // carry on with the same host DRAM buffer. + if (layer->CacheOnDram() && !user_input.IsDramType()) { + TRACE_SCOPE("SingleTpuRequest::AddInput::AddDRAMBuffer"); + auto buffer_or_error = + dram_allocator_->AllocateBuffer(layer->PaddedSizeBytes()); + if (buffer_or_error.ok()) { + auto dram_buffer = buffer_or_error.ValueOrDie(); + RETURN_IF_ERROR(dram_buffer->ReadFrom(host_input.ptr())); + host_input = Buffer(dram_buffer); + } else { + LOG(WARNING) << StringPrintf( + "Failed to allocate TPU DRAM buffer of size %d: ", + layer->PaddedSizeBytes()) + << buffer_or_error.status().message(); + } + } + + // At this point we are about to add host_input to the list of buffers + // that get mapped to TPU. If it is on host DRAM, we should make sure it is + // aligned, otherwise copy it to an aligned buffer. + if (host_input.IsPtrType() && !IsBufferAligned(host_input)) { + TRACE_SCOPE("SingleTpuRequest::AddInput::CopyForAlignment"); + // From here on, we need to make sure that accessing padding bytes will not + // cause problems, however the input buffer supplied by the user may not + // explicitly include padding bytes. To avoid always copying the input + // buffer, instead we ensure that reading memory slightly past the end of + // what was supplied by the user is safe and not going to page fault. + + // If the provided buffer is aligned, that implies that the padded end is + // also aligned, and therefore the padding bytes cannot cross a page + // boundary. So we can use it directly and avoid paying for a memcpy. + // (Unless we need to pad in between elements for hardware looping support.) + auto aligned_input = allocator_->MakeBuffer(layer->PaddedSizeBytes()); + memcpy(aligned_input.ptr(), host_input.ptr(), host_input.size_bytes()); + host_input = aligned_input; + } + + host_inputs_[name].push_back(host_input); + return util::OkStatus(); +} + +util::Status SingleTpuRequest::AddOutput(const std::string& name, + Buffer output) { + TRACE_SCOPE("SingleTpuRequest::AddOutput"); + + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + RETURN_IF_ERROR(executable_reference_.ValidateOutput(name, output)); + + VLOG(3) << StringPrintf("Adding output \"%s\" with %zu bytes.", name.c_str(), + output.size_bytes()); + ASSIGN_OR_RETURN(const auto* layer, executable_reference_.OutputLayer(name)); + + if (output.IsDramType() && !output.IsManagedType()) { + TRACE_SCOPE("SingleTpuRequest::AddOutput::PushToHostOutput"); + // Handle special case for user-created on-device DRAM buffer. + // 1. Use the user-provided buffer directly for model output. + // 2. There is no separate user-buffer to synchronize output data with. + // 3. There will be no opportunity for post-processing, e.g., re-layout. 
+ // Therefore, we do not accept a user-created on-device DRAM buffer + // that needs post-processing. + + // TODO -- When the proper test is implemented, use it to + // validate that this output buffer does not in fact need + // post-processing. + + host_outputs_[name].push_back(output); + } else { + TRACE_SCOPE("SingleTpuRequest::AddOutput::CreateTmpAndPushToHostOutput"); + // In all other cases, create a temporary buffer in host memory + // for the model output. The temporary output will need to be + // synchronized (potentially after post-processing) with the + // actual user-provided buffer. + + auto host_output = + GetOrCreateBatchOutput(layer, name) + .Slice(user_outputs_[name].size() * layer->PaddedSizeBytes(), + layer->PaddedSizeBytes()); + host_outputs_[name].push_back(std::move(host_output)); + } + + user_outputs_[name].push_back(std::move(output)); + + return util::Status(); // OK +} + +util::Status SingleTpuRequest::AddNoopInputs(const std::string& name, + int count) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + VLOG(3) << StringPrintf("Adding %d noop inputs for layer \"%s\".", count, + name.c_str()); + + ASSIGN_OR_RETURN(const auto* layer, executable_reference_.InputLayer(name)); + auto& inputs = host_inputs_[name]; + inputs.reserve(count); + + auto batch_buffer = CreateActivationBuffer(layer, count); + for (int i = 0; i < count; ++i) { + auto buffer = batch_buffer.Slice(i * layer->PaddedSizeBytes(), + layer->PaddedSizeBytes()); + inputs.push_back(buffer); + } + + return util::OkStatus(); +} + +util::Status SingleTpuRequest::AddNoopOutputs(const std::string& name, + int count) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + VLOG(3) << StringPrintf("Adding %d noop outputs for layer \"%s\".", count, + name.c_str()); + + ASSIGN_OR_RETURN(const auto* layer, executable_reference_.OutputLayer(name)); + auto& outputs = host_outputs_[name]; + outputs.reserve(count); + + const auto& batch_buffer = GetOrCreateBatchOutput(layer, name); + const int total_batches = executable_reference_.BatchSize(); + for (int i = total_batches - count; i < total_batches; ++i) { + auto buffer = batch_buffer.Slice(i * layer->PaddedSizeBytes(), + layer->PaddedSizeBytes()); + outputs.push_back(buffer); + } + + return util::OkStatus(); +} + +util::Status SingleTpuRequest::MapDataBuffers() { + // Map activations except parameters, which is done at registration time. + TRACE_SCOPE("Request::MapDataBuffers"); + RETURN_IF_ERROR( + device_buffer_mapper_->MapScratch(executable_reference_.scratch())); + RETURN_IF_ERROR(device_buffer_mapper_->MapInputs(host_inputs_)); + RETURN_IF_ERROR(device_buffer_mapper_->MapOutputs(host_outputs_)); + return util::Status(); // OK +} + +util::Status SingleTpuRequest::MapInstructionBuffers() { + TRACE_SCOPE("Request::MapInstructionBuffers"); + RETURN_IF_ERROR(device_buffer_mapper_->MapInstructions( + instruction_buffers_->GetBuffers())); + + return util::Status(); // OK +} + +util::Status SingleTpuRequest::Cleanup() { + // Note that these calls are a no-op if request is already in a clean state. + RETURN_IF_ERROR(device_buffer_mapper_->UnmapAll()); + if (instruction_buffers_) { + // Returns the instruction buffers back to executable references, so that + // we could reuse it in the next request. + // This saves time allocating / copying new host memory buffers. 
+ const_cast(executable_reference_) + .ReturnInstructionBuffers(std::move(instruction_buffers_)); + } + + return util::Status(); // OK +} + +util::Status SingleTpuRequest::Validate() { + TRACE_SCOPE("Request::Validate"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + + // Validate instruction bit stream. + if (VectorLength(executable().instruction_bitstreams()) == 0) { + return util::InvalidArgumentError( + "Executable does not contain instruction bitstream."); + } + for (const auto& chunk : *executable().instruction_bitstreams()) { + if (VectorLength(chunk->bitstream()) == 0) { + return util::InvalidArgumentError( + "Executable contains empty instruction bitstream chunk."); + } + } + + // Number of input / outputs should match with executable. + if (host_inputs_.size() != VectorLength(executable().input_layers())) { + return util::InvalidArgumentError( + "Added inputs does not match the number of required inputs for " + "executable."); + } + + if (host_outputs_.size() != VectorLength(executable().output_layers())) { + return util::InvalidArgumentError( + "Added outputs does not match the number of required outputs for " + "executable."); + } + + // Number of input / output buffers must match configured batch size. + for (const auto& name_and_input : host_inputs_) { + if (name_and_input.second.size() != executable().batch_size()) { + return util::InvalidArgumentError( + StringPrintf("Number of input buffers for \"%s\" does not match " + "configured batch size. expected=%d, actual=%zu.", + name_and_input.first.c_str(), executable().batch_size(), + name_and_input.second.size())); + } + } + + for (const auto& name_and_output : host_outputs_) { + if (name_and_output.second.size() != executable().batch_size()) { + return util::InvalidArgumentError( + StringPrintf("Number of output buffers for \"%s\" does not match " + "configured batch size. expected=%d, actual=%zu.", + name_and_output.first.c_str(), executable().batch_size(), + name_and_output.second.size())); + } + } + + return util::Status(); // OK +} + +util::Status SingleTpuRequest::Prepare() { + TRACE_SCOPE("Request::Prepare"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kUninitialized)); + + // Reuses old instruction buffers if available. + // If not this will create new instruction buffers. + if (!instruction_buffers_) { + instruction_buffers_ = + const_cast(executable_reference_) + .GetInstructionBuffers(allocator_); + } + + RETURN_IF_ERROR(MapDataBuffers()); + VLOG(10) << "MapDataBuffers() done."; + + // Update the instruction stream to link the input, output and parameter + // addresses. + instruction_buffers_->LinkInstructionBuffers( + parameter_device_buffer_, device_buffer_mapper_.get(), + *executable().instruction_bitstreams()); + + // Mapping of instruction buffers must happen after instructions have been + // been patched with linked addresses. Any further modifications to + // instructions may not be visible to device due to cache coherency issues. 
+ auto status = MapInstructionBuffers(); + if (!status.ok()) { + status.Update(device_buffer_mapper_->UnmapAll()); + return status; + } + VLOG(10) << "MapInstructionBuffers() done."; + + return SetState(kCreated); +} + +util::Status SingleTpuRequest::NotifyRequestSubmitted() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kCreated)); + + VLOG(3) << StringPrintf("[%d] NotifyRequestSubmitted()", id_); + return SetState(kSubmitted); +} + +util::Status SingleTpuRequest::NotifyRequestActive() { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kSubmitted)); + + VLOG(3) << StringPrintf("[%d] NotifyRequestActive()", id_); + return SetState(kActive); +} + +util::Status SingleTpuRequest::NotifyCompletion(util::Status status) { + TRACE_SCOPE("Request::NotifyCompletion"); + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateState(kActive)); + + // First notify the parent request. This will affect timing measurements so it + // needs to be done first. + parent_request_->NotifyCompletion(type()); + VLOG(3) << StringPrintf("[%d] NotifyCompletion()", id_); + + // Cleanup first before notify, because we need to unmap buffers first to + // guarantee that output buffers are coherent. + status.Update(Cleanup()); + + RETURN_IF_ERROR(PostProcessOutputBuffers()); + + if (done_) { + done_(id_, status); + + // The |done_| callback may be a lambda that directly or indirectly holds a + // shared_ptr to this request. If that happens, we will have a circular + // reference through a shared_ptr, which will cause a memory leak. Prevent + // the leak by explicitly destructing the lambda here. + done_ = nullptr; + } + + return SetState(kDone); +} + +util::StatusOr> SingleTpuRequest::GetDmaInfos() const { + StdMutexLock lock(&mutex_); + if (state_ != kCreated && state_ != kSubmitted) { + return util::FailedPreconditionError( + StringPrintf("Unexpected call to GetDmaInfos in state_ = %d.", state_)); + } + + return extractor_.ExtractDmaInfos(executable_reference_, + *device_buffer_mapper_); +} + +util::Status SingleTpuRequest::Cancel() { + StdMutexLock lock(&mutex_); + + VLOG(3) << StringPrintf("[%d] Cancel()", id_); + + if (state_ == kUninitialized || state_ == State::kCreated) { + return util::FailedPreconditionError( + StringPrintf("Cannot cancel in state_=%d.", state_)); + } + + // If State::kSubmitted, or kActive OK to cancel. + if (state_ == State::kSubmitted || state_ == State::kActive) { + // Run completed callback. + // TODO: Share common code with NotifyCompletion. + if (done_) { + done_(id_, util::CancelledError("Request cancelled.")); + done_ = nullptr; // See above for why this is needed. + } + + RETURN_IF_ERROR(Cleanup()); + return SetState(kDone); + } + + // If State::kDone, do nothing because request is already complete. + return util::Status(); // OK +} + +util::Status SingleTpuRequest::ValidateState(State expected_state) const { + if (state_ != expected_state) { + return util::FailedPreconditionError(StringPrintf( + "Bad request state. 
expected=%d, actual=%d.", expected_state, state_)); + } + return util::Status(); // OK +} + +util::Status SingleTpuRequest::SetState(State next_state) { + VLOG(5) << StringPrintf("[%d] SetState old=%d, new=%d.", id_, state_, + next_state); + switch (state_) { + case kUninitialized: + if (next_state == kCreated) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kCreated: + if (next_state == kSubmitted) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kSubmitted: + if (next_state == kActive || next_state == kDone) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kActive: + if (next_state == kDone) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kDone: + break; + } + + // Illegal state transition. + return util::FailedPreconditionError(StringPrintf( + "Invalid state transition. current=%d, next=%d.", state_, next_state)); +} + +const Buffer& SingleTpuRequest::InputBuffer(const std::string& name, + int batch) const { + StdMutexLock lock(&mutex_); + return host_inputs_.at(name)[batch]; +} + +Buffer SingleTpuRequest::OutputBuffer(const std::string& name, + int batch) const { + StdMutexLock lock(&mutex_); + return host_outputs_.at(name)[batch]; +} + +bool SingleTpuRequest::IsBufferAligned(const Buffer& buffer) { + return reinterpret_cast(buffer.ptr()) % alignment_bytes_ == 0; +} + +util::Status SingleTpuRequest::PostProcessOutputBuffers() { + TRACE_SCOPE("SingleTpuRequest::PostProcessOutputBuffers"); + for (const auto& name_and_output : host_outputs_) { + const auto& layer_name = name_and_output.first; + auto user_output_name_and_buffers = user_outputs_.find(layer_name); + if (user_output_name_and_buffers == user_outputs_.end()) { + return util::InternalError( + StringPrintf("Unable to find output layer %s in user outputs map.", + layer_name.c_str())); + } + + const auto& host_output_buffers = name_and_output.second; + auto& user_output_buffers = user_output_name_and_buffers->second; + if (host_output_buffers.size() < user_output_buffers.size()) { + return util::InternalError( + StringPrintf("Found %zu user output buffers which is greater than " + "%zu host output buffers for layer %s.", + user_output_buffers.size(), host_output_buffers.size(), + layer_name.c_str())); + } + + ASSIGN_OR_RETURN(const auto* layer, + executable_reference_.OutputLayer(layer_name)); + + for (int i = 0; i < user_output_buffers.size(); ++i) { + Buffer user_buffer = user_output_buffers[i]; + if (user_buffer.IsDramType() && !user_buffer.IsManagedType()) { + // No support for post-processing of user output buffer allocated + // on device DRAM. + + // TODO -- When the proper test is implemented, use it to + // validate that this output buffer does not in fact need + // post-processing. + continue; + } + // Otherwise, always do post-processing even if tests indicate that + // it is not needed: the post-processing will also synchronize data + // between the runtime-managed (host) and user-provided output buffer. 
+ + Buffer host_buffer = host_output_buffers[i]; + if (host_buffer.IsDramType()) { + TRACE_SCOPE( + "SingleTpuRequest::PostProcessOutputBuffers::DramToHostOutput"); + ASSIGN_OR_RETURN(auto dram_buffer, host_buffer.GetDramBuffer()); + host_buffer = allocator_->MakeBuffer(layer->PaddedSizeBytes()); + RETURN_IF_ERROR(dram_buffer->WriteTo(host_buffer.ptr())); + } + + { + TRACE_SCOPE("SingleTpuRequest::PostProcessOutputBuffers::Relayout"); + RETURN_IF_ERROR(layer->Relayout(user_buffer.ptr(), host_buffer.ptr())); + } + + if (layer->SignedDataType()) { + TRACE_SCOPE( + "SingleTpuRequest::PostProcessOutputBuffers::" + "TransformSignedDataType"); + RETURN_IF_ERROR(layer->TransformSignedDataType(user_buffer)); + } + } + } + + return util::OkStatus(); +} + +Buffer SingleTpuRequest::ScatterInput(const Buffer& input, + const api::LayerInformation* layer) { + // For iterative models, we need to add padding after each iteration. + auto aligned_input = allocator_->MakeBuffer(layer->PaddedSizeBytes()); + auto padded_single_execution_size = + layer->PaddedSizeBytes() / layer->execution_count_per_inference(); + auto actual_single_execution_size = + layer->ActualSizeBytes() / layer->execution_count_per_inference(); + for (int i = 0; i < layer->execution_count_per_inference(); i++) { + memcpy(aligned_input.ptr() + padded_single_execution_size * i, + input.ptr() + actual_single_execution_size * i, + actual_single_execution_size); + } + + return aligned_input; +} + +Buffer SingleTpuRequest::TryCreateDramBuffer(size_t size_bytes) { + auto buffer_or_error = dram_allocator_->AllocateBuffer(size_bytes); + if (buffer_or_error.ok()) { + return Buffer(std::move(buffer_or_error).ValueOrDie()); + } + + LOG(WARNING) << StringPrintf( + "Failed to allocate TPU DRAM buffer of size %zu: ", + size_bytes) + << buffer_or_error.status().message(); + return allocator_->MakeBuffer(size_bytes); +} + +Buffer SingleTpuRequest::CreateActivationBuffer( + const api::LayerInformation* layer, int batches) { + // TODO: We can't use DRAM buffers when also using batching. + // Note that we could have allocated separate per-batch on-chip DRAM buffers + // instead of using host DRAM, but we don't have a clear use case to evaluate + // the power/perf tradeoff. + if (layer->CacheOnDram() && batches == 1) { + return TryCreateDramBuffer(layer->PaddedSizeBytes()); + } else { + return allocator_->MakeBuffer(layer->PaddedSizeBytes() * batches); + } +} + +Buffer SingleTpuRequest::GetOrCreateBatchOutput( + const api::LayerInformation* layer, const std::string& name) { + const auto existing = batch_outputs_.find(name); + if (existing == batch_outputs_.end()) { + auto batch_output = + CreateActivationBuffer(layer, executable_reference_.BatchSize()); + batch_outputs_[name] = batch_output; + return batch_output; + } else { + return existing->second; + } +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/single_tpu_request.h b/driver/single_tpu_request.h new file mode 100644 index 0000000..364adf5 --- /dev/null +++ b/driver/single_tpu_request.h @@ -0,0 +1,253 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_SINGLE_TPU_REQUEST_H_ +#define DARWINN_DRIVER_SINGLE_TPU_REQUEST_H_ + +#include + +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "api/buffer.h" +#include "driver/allocator.h" +#include "driver/device_buffer.h" +#include "driver/device_buffer_mapper.h" +#include "driver/dma_info.h" +#include "driver/dma_info_extractor.h" +#include "driver/executable_util.h" +#include "driver/memory/address_space.h" +#include "driver/memory/dram_allocator.h" +#include "driver/package_registry.h" +#include "driver/request.h" +#include "driver/tpu_request.h" +#include "executable/executable_generated.h" +#include "port/array_slice.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// An single inference request to TPU. This class is thread-safe. +class SingleTpuRequest : public TpuRequest { + public: + // Constructs a request object for executing the given |executable_reference|. + // |done| is the callback function executed when request is complete. + // TODO: Make this constructor private and create request objects + // through a factory function in the driver. + explicit SingleTpuRequest( + int id, const std::shared_ptr parent_request, + const ExecutableReference* executable_reference, Allocator* allocator, + DramAllocator* dram_allocator, + std::unique_ptr device_buffer_mapper, + const DmaInfoExtractor* extractor, uint64 alignment_bytes, + RequestType type); + explicit SingleTpuRequest( + int id, const std::shared_ptr parent_request, + const ExecutableReference* executable_reference, Allocator* allocator, + DramAllocator* dram_allocator, + std::unique_ptr device_buffer_mapper, + const DmaInfoExtractor* extractor, uint64 alignment_bytes, Done done, + RequestType type); + + SingleTpuRequest(SingleTpuRequest&& rhs) = delete; + SingleTpuRequest& operator=(SingleTpuRequest&& rhs) = delete; + SingleTpuRequest(const SingleTpuRequest&) = delete; + SingleTpuRequest& operator=(const SingleTpuRequest&) = delete; + ~SingleTpuRequest() override; + + util::Status SetDone(Done done) LOCKS_EXCLUDED(mutex_) override; + util::Status AddInput(const std::string& name, const Buffer& input) + LOCKS_EXCLUDED(mutex_) override; + util::Status AddOutput(const std::string& name, Buffer output) + LOCKS_EXCLUDED(mutex_) override; + util::Status AddNoopInputs(const std::string& name, int count) + LOCKS_EXCLUDED(mutex_) override; + util::Status AddNoopOutputs(const std::string& name, int count) + LOCKS_EXCLUDED(mutex_) override; + const Buffer& InputBuffer(const std::string& name, int batch) const override + LOCKS_EXCLUDED(mutex_); + Buffer OutputBuffer(const std::string& name, int batch) const override + LOCKS_EXCLUDED(mutex_); + util::Status Validate() LOCKS_EXCLUDED(mutex_) override; + util::Status Prepare() LOCKS_EXCLUDED(mutex_) override; + util::Status Cancel() LOCKS_EXCLUDED(mutex_) override; + + // TODO: The following functions needs to restricted for use + // by the driver only. 
+ util::Status NotifyRequestSubmitted() LOCKS_EXCLUDED(mutex_) override; + util::Status NotifyRequestActive() LOCKS_EXCLUDED(mutex_) override; + util::Status NotifyCompletion(util::Status status) + LOCKS_EXCLUDED(mutex_) override; + + int id() const override { return id_; } + + RequestType type() const override { return type_; } + + int num_instruction_bitstream_chunks() const override { + return executable().instruction_bitstreams()->Length(); + } + + util::StatusOr> GetDmaInfos() const + LOCKS_EXCLUDED(mutex_) override; + + const ExecutableReference& executable_reference() const override { + return executable_reference_; + } + + DeviceBufferMapper* device_buffer_mapper() const { + return device_buffer_mapper_.get(); + } + + private: + // Compute request state. State transitions : + // kUninitialized -> kCreated -> kSubmitted -> kActive -> kDone + // kUninitialized -> kCreated -> kSubmitted -> kDone [if cancelled]. + enum State { + kUninitialized, // Request not initialized. + kCreated, // Request created, but pending issue to DarwiNN. + kSubmitted, // Request submitted and in queue for issuing to DarwiNN. + kActive, // Request issued to DarwiNN, pending results. + kDone, // Request in terminal state. + }; + + // Attempts a state transition to the given state. + util::Status SetState(State next_state) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Validates that we are in the expected state. + util::Status ValidateState(State expected_state) const + SHARED_LOCKS_REQUIRED(mutex_); + + // Maps all data buffers (input, output, parameters) to the DarwiNN address + // space. + util::Status MapDataBuffers() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Map instruction buffers to the DarwiNN address space. + util::Status MapInstructionBuffers() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Unmaps all buffers and frees the allocated instruction and parameter + // buffers if any. Reverse of what is done in #Prepare(). + util::Status Cleanup() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Convenience function that returns the backing executable in + // |executable_reference_|. + const darwinn::Executable& executable() const { + return executable_reference_.executable(); + } + + // Returns true if the alignment requirement for a provided buffer is met. + bool IsBufferAligned(const Buffer& buffer); + + // Post processes the output buffers. This includes: + // 1. Relayout the outputs in host_outputs_ to user-expected layouts and + // store them in the user_outputs_. Some outputs do not need a relayout + // and for those we set the same user-provided buffer in the host_outputs_. + // Those are ignored by this method. + // 2. Perform sign conversion. + util::Status PostProcessOutputBuffers() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Tries to create a TPU DRAM buffer. If it fails, it falls back to create a + // host DRAM buffer. + Buffer TryCreateDramBuffer(size_t size_bytes); + + // Creates and returns a new buffer for |batches| copies of the activations of + // a provided layer. + Buffer CreateActivationBuffer(const api::LayerInformation* layer, + int batches); + + // Gets a contiguous buffer that holds all batched host_outputs_ for a given + // layer. Lazily creates the buffer on first access, but always returns the + // same buffer when called for the same layer. + Buffer GetOrCreateBatchOutput(const api::LayerInformation* layer, + const std::string& name) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Copies a provided input buffer in such a way that inputs of each iteration + // has the alignment requirements. 
+ Buffer ScatterInput(const Buffer& input, const api::LayerInformation* layer); + + // Unique ID for request. + const int id_; + + // Track the type of TPU request for logging purposes. + const RequestType type_; + + // The parent driver request this TpuRequest is a part of. We mostly hold on + // to a shared_pointer to the parent here to make sure it outlives all its + // TPU requests. + const std::shared_ptr parent_request_; + + // Executable for the compute request. + const ExecutableReference& executable_reference_; + + // Buffer allocator. + Allocator* const allocator_; + + // On-Chip DRAM Buffer allocator. + DramAllocator* dram_allocator_; + + // Maps and stores all device buffers. + std::unique_ptr device_buffer_mapper_; + + // DMA info extractor. + const DmaInfoExtractor& extractor_; + + // Maintains integrity of the request object. + mutable std::mutex mutex_; + + // Request state. + State state_ GUARDED_BY(mutex_){kUninitialized}; + + // Infeed and outfeed host buffers. + // host_*[layer_name][batch_id] = buffer. + Buffer::NamedMap host_inputs_ GUARDED_BY(mutex_); + Buffer::NamedMap host_outputs_ GUARDED_BY(mutex_); + + // Cache of batch-sized output buffers that are used to ensure + // host_outputs_ are contiguous. + std::unordered_map batch_outputs_ GUARDED_BY(mutex_); + + // Buffers to contain the user-facing outputs. The difference between user and + // host outputs is that host_outputs have a TPU-friendly layout while user + // outputs have a user-friendly layout. For DMAs and basically anything from + // driver down, we mostly deal host_outputs_. And for anything from driver up + // or methods exposed in the API, we deal with user_outputs_. + Buffer::NamedMap user_outputs_ GUARDED_BY(mutex_); + + // Final request completion callback. + Done done_ GUARDED_BY(mutex_); + + // A copy of the mapped parameters owned by executable reference. + const DeviceBuffer parameter_device_buffer_ GUARDED_BY(mutex_); + // Buffers for instructions. + std::unique_ptr instruction_buffers_ GUARDED_BY(mutex_); + + // The alignment requirement for input and output buffers provided by the + // user. + const uint64 alignment_bytes_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_SINGLE_TPU_REQUEST_H_ diff --git a/driver/time_stamper/BUILD b/driver/time_stamper/BUILD new file mode 100644 index 0000000..86229a5 --- /dev/null +++ b/driver/time_stamper/BUILD @@ -0,0 +1,57 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: time related libraries used by DarwiNN drivers. 
+ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +cc_library( + name = "driver_time_stamper", + srcs = ["driver_time_stamper.cc"], + hdrs = ["driver_time_stamper.h"], + deps = [ + ":time_stamper", + "//port", + ], +) + +cc_library( + name = "time_stamper", + hdrs = ["time_stamper.h"], + deps = ["//port"], +) + +cc_library( + name = "driver_time_stamper_factory", + hdrs = ["driver_time_stamper_factory.h"], + deps = [ + ":driver_time_stamper", + ":time_stamper", + ":time_stamper_factory", + "//port", + ], +) + +cc_library( + name = "time_stamper_factory", + hdrs = ["time_stamper_factory.h"], + deps = [ + ":time_stamper", + "//port", + ], +) diff --git a/driver/time_stamper/driver_time_stamper.cc b/driver/time_stamper/driver_time_stamper.cc new file mode 100644 index 0000000..a634d69 --- /dev/null +++ b/driver/time_stamper/driver_time_stamper.cc @@ -0,0 +1,29 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/time_stamper/driver_time_stamper.h" + +#include "port/time.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +int64 DriverTimeStamper::GetTimeNanoSeconds() const { + return GetCurrentTimeNanos(); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/time_stamper/driver_time_stamper.h b/driver/time_stamper/driver_time_stamper.h new file mode 100644 index 0000000..9a437f3 --- /dev/null +++ b/driver/time_stamper/driver_time_stamper.h @@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_H_ +#define DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_H_ + +#include "driver/time_stamper/time_stamper.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Microsec-resolution monotonic clock. 
+class DriverTimeStamper : public TimeStamper { + public: + DriverTimeStamper() = default; + ~DriverTimeStamper() = default; + + int64 GetTimeNanoSeconds() const override; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_H_ diff --git a/driver/time_stamper/driver_time_stamper_factory.h b/driver/time_stamper/driver_time_stamper_factory.h new file mode 100644 index 0000000..166160f --- /dev/null +++ b/driver/time_stamper/driver_time_stamper_factory.h @@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_FACTORY_H_ +#define DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_FACTORY_H_ + +#include "driver/time_stamper/driver_time_stamper.h" +#include "driver/time_stamper/time_stamper_factory.h" +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Factory class for creating DriverTimeStamper objects. +class DriverTimeStamperFactory : public TimeStamperFactory { + public: + DriverTimeStamperFactory() = default; + virtual ~DriverTimeStamperFactory() = default; + + std::unique_ptr CreateTimeStamper() override { + return gtl::MakeUnique(); + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TIME_STAMPER_DRIVER_TIME_STAMPER_FACTORY_H_ diff --git a/driver/time_stamper/time_stamper.h b/driver/time_stamper/time_stamper.h new file mode 100644 index 0000000..5d83418 --- /dev/null +++ b/driver/time_stamper/time_stamper.h @@ -0,0 +1,73 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_H_ +#define DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Abstract class for timestamping. Make it a class so a stateful mock can be +// used in tests. +// TODO move timestamper to a common location to share between +// driver and driver2. +class TimeStamper { + public: + // Multiplier factors to help convert between various timer resolutions. 
+ static constexpr int64 kNanoSecondsPerMicroSecond = 1000; + static constexpr int64 kNanoSecondsPerMilliSecond = + 1000 * kNanoSecondsPerMicroSecond; + static constexpr int64 kNanoSecondsPerSecond = + 1000 * kNanoSecondsPerMilliSecond; + static constexpr int64 kMicroSecondsPerSecond = + kNanoSecondsPerSecond / kNanoSecondsPerMicroSecond; + static constexpr int64 kMilliSecondsPerSecond = + kNanoSecondsPerSecond / kNanoSecondsPerMilliSecond; + static constexpr int64 kInvalidTimestamp = -1; + + TimeStamper() = default; + + // This class is neither copyable nor movable. + TimeStamper(const TimeStamper &) = delete; + TimeStamper &operator=(const TimeStamper &) = delete; + + virtual ~TimeStamper() = default; + + // Returns a monotonically-increasing timestamp. Base resolution is nano + // second. Default implementations are provided for other resolutions. + // However, if it is not possible to provide nano second resolution, + // implementations can choose to override lower resolution methods explictly. + virtual int64 GetTimeNanoSeconds() const = 0; + + virtual int64 GetTimeMicroSeconds() const { + return GetTimeNanoSeconds() / kNanoSecondsPerMicroSecond; + } + + virtual int64 GetTimeMilliSeconds() const { + return GetTimeNanoSeconds() / kNanoSecondsPerMilliSecond; + } + + virtual int64 GetTimeSeconds() const { + return GetTimeNanoSeconds() / kNanoSecondsPerSecond; + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_H_ diff --git a/driver/time_stamper/time_stamper_factory.h b/driver/time_stamper/time_stamper_factory.h new file mode 100644 index 0000000..e6a4fed --- /dev/null +++ b/driver/time_stamper/time_stamper_factory.h @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_FACTORY_H_ +#define DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_FACTORY_H_ + +#include "driver/time_stamper/time_stamper.h" + +#include "port/ptr_util.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Factory class for allocating TimeStamper objects. +class TimeStamperFactory { + public: + TimeStamperFactory() = default; + + // This class is neither copyable nor movable. + TimeStamperFactory(const TimeStamperFactory &) = delete; + TimeStamperFactory &operator=(const TimeStamperFactory &) = delete; + + virtual ~TimeStamperFactory() = default; + + virtual std::unique_ptr CreateTimeStamper() = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TIME_STAMPER_TIME_STAMPER_FACTORY_H_ diff --git a/driver/top_level_handler.h b/driver/top_level_handler.h new file mode 100644 index 0000000..14bf593 --- /dev/null +++ b/driver/top_level_handler.h @@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TOP_LEVEL_HANDLER_H_ +#define DARWINN_DRIVER_TOP_LEVEL_HANDLER_H_ + +#include "port/status.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Interface for handling resets. +class TopLevelHandler { + public: + virtual ~TopLevelHandler() = default; + + // Opens reset handler. + virtual util::Status Open() { + return util::Status(); // OK + } + + // Closes reset handler. + virtual util::Status Close() { + return util::Status(); // OK + } + + // Quits from reset state. + virtual util::Status QuitReset() { + return util::Status(); // OK + } + + // Goes into reset state. + virtual util::Status EnableReset() { + return util::Status(); // OK + } + + // Enables/disables software clock gating - implementation must be idempotent. + virtual util::Status EnableSoftwareClockGate() { + return util::Status(); // OK + } + virtual util::Status DisableSoftwareClockGate() { + return util::Status(); // OK + } + + // Enables/disables hardware clock gating - implementation must be idempotent. + virtual util::Status EnableHardwareClockGate() { + return util::Status(); // OK + } + virtual util::Status DisableHardwareClockGate() { + return util::Status(); // OK + } +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TOP_LEVEL_HANDLER_H_ diff --git a/driver/tpu_request.h b/driver/tpu_request.h new file mode 100644 index 0000000..45a2dcf --- /dev/null +++ b/driver/tpu_request.h @@ -0,0 +1,115 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_TPU_REQUEST_H_ +#define DARWINN_DRIVER_TPU_REQUEST_H_ + +#include +#include +#include + +#include "api/buffer.h" +#include "api/request.h" +#include "driver/dma_info.h" +#include "driver/package_registry.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// An abstract class to represent an inference request to TPU. +class TpuRequest { + public: + // A type for request completion callback. + // The int argument is the same as return value of id(). + using Done = std::function; + + // Classify each TPU Request for logging. + using RequestType = api::Request::TimingEvent::TpuRequestType; + + TpuRequest() = default; + virtual ~TpuRequest() = default; + + // This class is not copyable nor movable. + TpuRequest(const TpuRequest&) = delete; + TpuRequest& operator=(const TpuRequest&) = delete; + + // Sets the callback function executed when request is complete. + virtual util::Status SetDone(Done done) = 0; + + // Adds an input or output buffer. 
This may be called repeatedly depending + // on the batch size as long as the request instance is not submitted. If + // supplied "name" does not exist or size constraints on the input and output + // buffers do not match executable, will return failure. Memory backing the + // |Buffer| instance must be valid throughout the life of the request. + virtual util::Status AddInput(const std::string& name, + const Buffer& input) = 0; + virtual util::Status AddOutput(const std::string& name, Buffer output) = 0; + + // Add a provided number of dummy input/output buffers. This is helpful for + // evening out the number of buffers to native batch size. + virtual util::Status AddNoopInputs(const std::string& name, int count) = 0; + virtual util::Status AddNoopOutputs(const std::string& name, int count) = 0; + + // Returns the input and output buffers that the TPU DMAs to. This is only for + // use in reference driver and similar. + virtual const Buffer& InputBuffer(const std::string& name, + int batch) const = 0; + virtual Buffer OutputBuffer(const std::string& name, int batch) const = 0; + + // Validates the constraints. + virtual util::Status Validate() = 0; + + // Prepares the request to be submitted. + virtual util::Status Prepare() = 0; + + // Cancels the pending request. Cancellation is best effort. Completion + // callback is called if not already. Canceling a completed request has + // no effect. + // Note: A single TpuRequest cancelation will not cause an immediate + // cancellation of the parent driver::Request. However, it will guarantee a + // cancellation status once the parent request calls its Done callback. + virtual util::Status Cancel() = 0; + + // Notifies that request is submitted to the driver, but not yet issued. + virtual util::Status NotifyRequestSubmitted() = 0; + + // Notifies that request is active. That is, request is issued to DarwiNN. + virtual util::Status NotifyRequestActive() = 0; + + // Notifies completion of the request with the given status. + virtual util::Status NotifyCompletion(util::Status status) = 0; + + // Returns request id. + virtual int id() const = 0; + + // Returns the TPU request type that is used for logging. + virtual RequestType type() const = 0; + + // Returns the number of instruction bitstream chunks. + virtual int num_instruction_bitstream_chunks() const = 0; + + // Returns a list of DMAs to be performed. + virtual util::StatusOr> GetDmaInfos() const = 0; + + // Returns executable reference. + virtual const ExecutableReference& executable_reference() const = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_TPU_REQUEST_H_ diff --git a/driver/usb/BUILD b/driver/usb/BUILD new file mode 100644 index 0000000..9c266eb --- /dev/null +++ b/driver/usb/BUILD @@ -0,0 +1,228 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# USB driver specific functionality. 
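
Editor's aside on the `TpuRequest` interface added above (a sketch, not part of the patch): the `AddNoop*` methods let a caller pad a partially filled batch up to the executable's native batch size before submission. The helper below is hypothetical; the layer names and the way the batch size is obtained are assumptions.

```cpp
// Hypothetical helper: pad a request so the number of buffers per layer
// matches the executable's native batch size before Validate()/Prepare().
util::Status PadToNativeBatch(TpuRequest& request,
                              const std::string& input_layer,
                              const std::string& output_layer,
                              int buffers_added, int native_batch_size) {
  const int missing = native_batch_size - buffers_added;
  if (missing <= 0) return util::Status();  // Already a full batch.
  RETURN_IF_ERROR(request.AddNoopInputs(input_layer, missing));
  return request.AddNoopOutputs(output_layer, missing);
}
```
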
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "usb_device_interface", + hdrs = ["usb_device_interface.h"], + deps = ["//port"], +) + +LIBUSB_OPTIONS_SRC = select({ + "//:windows": ["libusb_options_windows.cc"], + "//conditions:default": ["libusb_options_default.cc"], +}) + +cc_library( + name = "libusb_options", + srcs = LIBUSB_OPTIONS_SRC, + hdrs = ["libusb_options.h"], + deps = ["//port"] + select({ + "//:windows": ["@libusb//:headers"], + "//conditions:default": ["@libusb//:headers"], + }), +) + +cc_library( + name = "libusb_options_no_external_release", + srcs = LIBUSB_OPTIONS_SRC, + hdrs = ["libusb_options.h"], + deps = [ + "//port", + "//third_party/libusb", # statically linked + ], +) + +# libUSB is dynamically linked in this version. +cc_library( + name = "local_usb_device", + srcs = ["local_usb_device.cc"], + hdrs = ["local_usb_device.h"], + deps = [ + ":libusb_options", + ":usb_device_interface", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ] + select({ + "//:windows": ["@libusb//:headers"], + "//conditions:default": ["@libusb//:headers"], + }), +) + +# libUSB is statically linked in this version. +cc_library( + name = "local_usb_device_no_external_release", + srcs = ["local_usb_device.cc"], + hdrs = ["local_usb_device.h"], + deps = [ + ":libusb_options_no_external_release", + ":usb_device_interface", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + "//third_party/libusb", + ], +) + +cc_library( + name = "usb_standard_commands", + srcs = ["usb_standard_commands.cc"], + hdrs = ["usb_standard_commands.h"], + deps = [ + ":usb_device_interface", + "//port", + "//port:std_mutex_lock", + ], +) + +cc_library( + name = "usb_dfu_commands", + srcs = ["usb_dfu_commands.cc"], + hdrs = ["usb_dfu_commands.h"], + deps = [ + ":usb_device_interface", + ":usb_standard_commands", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + ], +) + +genrule( + name = "usb_latest_firmware", + srcs = [ + "apex_latest_single_ep.bin", + "apex_latest_multi_ep.bin", + ], + outs = [ + "usb_latest_firmware.h", + ], + cmd = """ + echo "namespace {" > $(location usb_latest_firmware.h) + for FILE in $(SRCS); do + FILENAME=$${FILE##*/} + FILEBASE=$${FILENAME%.*} + echo "const unsigned char ""$$FILEBASE"" [] = {" >> $(location usb_latest_firmware.h) + xxd -i < "$$FILE" >> $(location usb_latest_firmware.h) + echo "};" >> $(location usb_latest_firmware.h) + echo "constexpr unsigned int ""$$FILEBASE""_len = sizeof(""$$FILEBASE"")/sizeof(unsigned char);" >> $(location usb_latest_firmware.h) + echo "" >> $(location usb_latest_firmware.h) + done + echo "} // namespace" >> $(location usb_latest_firmware.h) + """, +) + +cc_library( + name = "usb_ml_commands", + srcs = ["usb_ml_commands.cc"], + hdrs = ["usb_ml_commands.h"], + deps = [ + ":usb_device_interface", + ":usb_standard_commands", + "//port", + "//port:std_mutex_lock", + ], +) + +cc_library( + name = "usb_registers", + srcs = ["usb_registers.cc"], + hdrs = ["usb_registers.h"], + deps = [ + ":usb_ml_commands", + "//driver/registers", + "//port", + ], +) + +cc_library( + name = "usb_dfu_util", + srcs = ["usb_dfu_util.cc"], + hdrs = ["usb_dfu_util.h"], + deps = [ + ":usb_device_interface", + ":usb_dfu_commands", + "//port", + "//port:tracing", + ], +) + +cc_library( + name = "usb_io_request", + srcs = ["usb_io_request.cc"], + hdrs = ["usb_io_request.h"], + deps = [ + ":usb_ml_commands", + 
"//driver:device_buffer", + "//driver:dma_chunker", + "//driver:dma_info", + "//port", + ], +) + +cc_library( + name = "usb_driver", + srcs = [ + "usb_driver.cc", + ":usb_latest_firmware", + ], + hdrs = ["usb_driver.h"], + deps = [ + ":usb_device_interface", + ":usb_dfu_commands", + ":usb_dfu_util", + ":usb_io_request", + ":usb_ml_commands", + ":usb_registers", + "//api:buffer", + "//api:watchdog", + "//driver", + "//driver:allocator", + "//driver:device_buffer", + "//driver:device_buffer_mapper", + "//driver:dma_chunker", + "//driver:dma_info", + "//driver:dma_info_extractor", + "//driver:hardware_structures", + "//driver:package_registry", + "//driver:request", + "//driver:run_controller", + "//driver:single_queue_dma_scheduler", + "//driver:single_tpu_request", + "//driver:top_level_handler", + "//driver:tpu_request", + "//driver/config", + "//driver/interrupt:interrupt_controller_interface", + "//driver/interrupt:top_level_interrupt_manager", + "//driver/memory:address_utilities", + "//driver/memory:dma_direction", + "//driver/memory:dram_allocator", + "//driver/memory:nop_address_space", + "//driver/registers", + "//driver/time_stamper", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//port:tracing", + ], +) diff --git a/driver/usb/apex_latest_multi_ep.bin b/driver/usb/apex_latest_multi_ep.bin new file mode 100755 index 0000000..6894c34 Binary files /dev/null and b/driver/usb/apex_latest_multi_ep.bin differ diff --git a/driver/usb/apex_latest_single_ep.bin b/driver/usb/apex_latest_single_ep.bin new file mode 100755 index 0000000..7c387a9 Binary files /dev/null and b/driver/usb/apex_latest_single_ep.bin differ diff --git a/driver/usb/libusb_options.h b/driver/usb/libusb_options.h new file mode 100644 index 0000000..a1c70e9 --- /dev/null +++ b/driver/usb/libusb_options.h @@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_LIBUSB_OPTIONS_H_ +#define DARWINN_DRIVER_USB_LIBUSB_OPTIONS_H_ + +#include "port/statusor.h" +#if DARWINN_PORT_USE_EXTERNAL +#include "libusb/libusb.h" +#else // !DARWINN_PORT_USE_EXTERNAL +#include +#endif // DARWINN_PORT_USE_EXTERNAL + +namespace platforms { +namespace darwinn { +namespace driver { + +int SetLibUsbOptions(libusb_context* context); + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_LIBUSB_OPTIONS_H_ diff --git a/driver/usb/libusb_options_default.cc b/driver/usb/libusb_options_default.cc new file mode 100644 index 0000000..87cad4a --- /dev/null +++ b/driver/usb/libusb_options_default.cc @@ -0,0 +1,27 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/libusb_options.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// The implementation of SetLibUsbOptions for the default platform +// is a no-op. On other platforms, this may do something of interest. +int SetLibUsbOptions(libusb_context* context) { return LIBUSB_SUCCESS; } + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/libusb_options_windows.cc b/driver/usb/libusb_options_windows.cc new file mode 100644 index 0000000..606a91d --- /dev/null +++ b/driver/usb/libusb_options_windows.cc @@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/libusb_options.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +int SetLibUsbOptions(libusb_context* context) { + // The UsbDk backend makes libusb behave the most like the Linux version. + // Without using UsbDk, manual intervention with Administrator privileges + // is required to have the Windows USB stack detach the DFU configuration + // and reattach the fully-configured device. + auto status = libusb_set_option(context, LIBUSB_OPTION_USE_USBDK); + return status; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/local_usb_device.cc b/driver/usb/local_usb_device.cc new file mode 100644 index 0000000..f60fa5a --- /dev/null +++ b/driver/usb/local_usb_device.cc @@ -0,0 +1,1099 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
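
Editor's sketch (not part of the patch) of how the factory implemented in this file is typically used: enumerate devices by vendor/product ID, then open one by its path. The vendor/product IDs below are placeholders, and the `StatusOr` payload types are assumptions inferred from the surrounding code.

```cpp
// Assumes EnumerateDevices returns a list of path strings and OpenDevice
// returns an owning pointer to a UsbDeviceInterface.
util::Status OpenFirstMatchingDevice(LocalUsbDeviceFactory* factory) {
  ASSIGN_OR_RETURN(auto paths,
                   factory->EnumerateDevices(/*vendor_id=*/0x1234,
                                             /*product_id=*/0x5678));
  if (paths.empty()) return util::NotFoundError("no matching USB device");
  ASSIGN_OR_RETURN(auto device,
                   factory->OpenDevice(paths.front(), /*timeout_msec=*/6000));
  // |device| implements UsbDeviceInterface; SetConfiguration() and
  // ClaimInterface() would typically follow before any transfers.
  return util::Status();  // OK.
}
```
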
+ +#include "driver/usb/local_usb_device.h" + +#include + +#include "driver/usb/libusb_options.h" +#include "driver/usb/usb_device_interface.h" +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/time.h" +#include "port/tracing.h" + +#define VLOG_IF_ERROR(L, S) \ + if (!(S).ok()) { \ + VLOG((L)) << S << " " << __FILE__ << ":" << __LINE__; \ + } + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { + +// Max depth for USB 3 is 7. +constexpr int kMaxUsbPathDepth = 7; +constexpr const char* kUsbPathPrefix = "/sys/bus/usb/devices/"; + +// Automatic retry for control commands, to reduce failure rates. +constexpr int kMaxNumRetriesForCommands = 5; + +// Automatic retries for checking if device is available after Close. +constexpr int kMaxNumRetriesForClose = 3; + +util::Status ConvertLibUsbError(int error, const char* context) { + if (error >= 0) { + return util::Status(); // OK. + } + + std::string logline = StringPrintf("USB error %d [%s]", error, context); + + VLOG(1) << StringPrintf("%s: %s", __func__, logline.c_str()); + + switch (error) { + case LIBUSB_ERROR_INVALID_PARAM: + return util::InvalidArgumentError(logline); + case LIBUSB_ERROR_ACCESS: + return util::PermissionDeniedError(logline); + case LIBUSB_ERROR_NO_MEM: + return util::ResourceExhaustedError(logline); + case LIBUSB_ERROR_NO_DEVICE: + return util::UnavailableError(logline); + case LIBUSB_ERROR_NOT_FOUND: + return util::NotFoundError(logline); + case LIBUSB_ERROR_BUSY: + return util::DeadlineExceededError(logline); + case LIBUSB_ERROR_TIMEOUT: + return util::DeadlineExceededError(logline); + case LIBUSB_ERROR_OVERFLOW: + return util::DataLossError(logline); + case LIBUSB_ERROR_INTERRUPTED: + return util::CancelledError(logline); + case LIBUSB_ERROR_NOT_SUPPORTED: + return util::UnimplementedError(logline); + default: + return util::UnknownError(logline); + } +} + +util::Status ConvertLibUsbTransferStatus(libusb_transfer_status status, + const char* context) { + if (status == LIBUSB_TRANSFER_COMPLETED) { + return util::Status(); // OK. + } + + std::string logline = + StringPrintf("USB transfer error %d [%s]", status, context); + + VLOG(1) << StringPrintf("%s: %s", __func__, logline.c_str()); + + switch (status) { + case LIBUSB_TRANSFER_TIMED_OUT: + return util::DeadlineExceededError(logline); + case LIBUSB_TRANSFER_CANCELLED: + return util::CancelledError(logline); + case LIBUSB_TRANSFER_STALL: + return util::InvalidArgumentError(logline); + case LIBUSB_TRANSFER_NO_DEVICE: + return util::NotFoundError(logline); + case LIBUSB_TRANSFER_OVERFLOW: + return util::DataLossError(logline); + default: + return util::UnknownError(logline); + } +} + +// Automatically retries a libusb command on error. +template +util::Status AutoRetryLibUsbCommand(const LibUsbCommand& func, + const char* context, + int* command_result = nullptr) { + int result = 0; + for (int attempt_count = 0; attempt_count < kMaxNumRetriesForCommands; + ++attempt_count) { + result = func(); + if (result < 0) { + (void)ConvertLibUsbError(result, context); + VLOG(1) << StringPrintf("[%s] failed [%d].", context, attempt_count + 1); + } else { + break; + } + } + + if (command_result) { + *command_result = result; + } + return ConvertLibUsbError(result, context); +} + +// Find if a device exists at a given bus/port combination, with retries. 
+// Used to detect if a device has finished being released by the OS during +// Close. +util::Status FindDeviceByBusAndPortWithRetries(libusb_context* context, + int bus_number, + int port_number) { + for (int attempt_count = 0; attempt_count < kMaxNumRetriesForClose; + ++attempt_count) { + libusb_device** device_list; + ssize_t device_count = libusb_get_device_list(context, &device_list); + // Clean up the device list when we leave this scope. + auto device_list_cleaner = + MakeCleanup([device_list] { libusb_free_device_list(device_list, 1); }); + bool found = false; + for (int i = 0; i < device_count; i++) { + libusb_device* device = device_list[i]; + int device_bus_number = libusb_get_bus_number(device); + int device_port_number = libusb_get_port_number(device); + if (device_port_number == port_number && + device_bus_number == bus_number) { + found = true; + break; + } + } + if (found) { + return util::Status(); + } else { + Sleep(1); + } + } + + return util::NotFoundError(StringPrintf( + "Could not find device on bus %d and port %d.", bus_number, port_number)); +} + +} // namespace + +LocalUsbDevice::LocalUsbDevice(libusb_device_handle* handle, bool use_zero_copy, + libusb_context* context) + : use_zero_copy_(use_zero_copy), + libusb_handle_(handle), + libusb_context_(context) { + CHECK(handle != nullptr); + CHECK(context != nullptr); + VLOG(10) << __func__; + + libusb_keep_running_ = true; + libusb_event_thread_ = std::thread([this]() NO_THREAD_SAFETY_ANALYSIS { + TRACE_START_THREAD("LocalUsbDeviceEventThread"); + while (libusb_keep_running_) { + libusb_handle_events(libusb_context_); + } + }); +} + +LocalUsbDevice::~LocalUsbDevice() { + VLOG(10) << __func__; + (void)Close(CloseAction::kNoReset); +} + +util::Status LocalUsbDevice::CheckForNullHandle(const char* context) const { + if (libusb_handle_ == nullptr) { + return util::FailedPreconditionError(context); + } + return util::Status(); // OK. +} + +void LocalUsbDevice::TryCancelAllTransfers() { + StdMutexLock lock(&mutex_); + DoCancelAllTransfers(); +} + +void LocalUsbDevice::DoCancelAllTransfers() { + { + StdCondMutexLock cond_lock(&async_callback_mutex_); + // Cancel all async transfer. + VLOG(9) << StringPrintf("%s: cancelling %d async transfers", __func__, + static_cast(async_transfers_.size())); + for (auto transfer_control_block : async_transfers_) { + VLOG_IF_ERROR( + 1, ConvertLibUsbError(libusb_cancel_transfer(transfer_control_block), + __func__)); + } + + VLOG(9) << StringPrintf("%s: waiting for all async transfers to complete", + __func__); + + // Wait for all async transfer to complete. + // This could take some time, as cancel may or may not work. + while (!async_transfers_.empty()) { + cond_.wait(cond_lock); + } + } + + VLOG(9) << StringPrintf("%s: all async transfers have completed", __func__); +} + +// TODO use status update to record the first failure. +util::Status LocalUsbDevice::Close(CloseAction action) { + TRACE_SCOPE("LocalUsbDevice::Close"); + + StdMutexLock lock(&mutex_); + + VLOG(6) << StringPrintf("%s: closing device %p ", __func__, libusb_handle_); + + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Perform forceful reset if so specified. 
+ switch (action) { + case CloseAction::kForcefulPortReset: + case CloseAction::kForcefulChipReset: { + TRACE_SCOPE("LocalUsbDevice::Close:forceful_reset"); + VLOG(1) << StringPrintf("%s: forcefully reset device %p", __func__, + libusb_handle_); + VLOG_IF_ERROR( + 1, ConvertLibUsbError(libusb_reset_device(libusb_handle_), __func__)); + break; + } + default: { + TRACE_SCOPE("LocalUsbDevice::Close:release_interface"); + // Release all interfaces claimed for this handle. + // This is not needed if we reset the device upfront. + for (auto interface_id : claimed_interfaces_) { + VLOG(9) << StringPrintf("%s: releasing claimed interface %d", __func__, + interface_id); + VLOG_IF_ERROR(1, ConvertLibUsbError(libusb_release_interface( + libusb_handle_, interface_id), + __func__)); + } + break; + } + } + + DoCancelAllTransfers(); + + // Release all transfer buffers allocated for this handle. + VLOG(9) << StringPrintf("%s: releasing %d transfer buffers", __func__, + static_cast(transfer_buffers_.size())); + + for (auto& record : transfer_buffers_) { + VLOG_IF_ERROR(1, DoReleaseTransferBuffer(record.second)); + } + transfer_buffers_.clear(); + + // Perform graceful reset if so specified. + switch (action) { + case CloseAction::kGracefulPortReset: + case CloseAction::kGracefulChipReset: { + TRACE_SCOPE("LocalUsbDevice::Close:graceful_reset"); + VLOG(9) << StringPrintf("%s: performing graceful reset", __func__); + VLOG_IF_ERROR( + 1, ConvertLibUsbError(libusb_reset_device(libusb_handle_), __func__)); + break; + } + default: { + // Do nothing. + break; + } + } + + libusb_keep_running_ = false; + + // Get the libusb bus/port number before closing. + int this_bus_number, this_port_number; + libusb_device* this_dev = libusb_get_device(libusb_handle_); + this_bus_number = libusb_get_bus_number(this_dev); + this_port_number = libusb_get_port_number(this_dev); + + // Close the libusb device handle. Event thread is awaken by this + // action. + libusb_close(libusb_handle_); + libusb_handle_ = nullptr; + libusb_event_thread_.join(); + + // Block until the closed device reappears on the USB bus (or for + // kMaxNumRetriesForClose checks). + VLOG_IF_ERROR(1, FindDeviceByBusAndPortWithRetries( + libusb_context_, this_bus_number, this_port_number)); + + libusb_exit(libusb_context_); + libusb_context_ = nullptr; + + VLOG(9) << StringPrintf("%s: final clean up completed", __func__); + + return util::Status(); // OK. +} + +util::Status LocalUsbDevice::SetConfiguration(int configuration) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Claimed interfaces for current configuration should already be released + // before we change configuration. + if (!claimed_interfaces_.empty()) { + VLOG(1) << StringPrintf("%s Claimed interfaces have not been released", + __func__); + claimed_interfaces_.clear(); + } + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + return AutoRetryLibUsbCommand( + [=] { + return libusb_set_configuration(libusb_handle_alias, configuration); + }, + __func__); +} + +util::Status LocalUsbDevice::ClaimInterface(int interface_number) { + TRACE_SCOPE("LocalUsbDevice::ClaimInterface"); + VLOG(10) << __func__; + + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. 
+ auto* libusb_handle_alias = libusb_handle_; + RETURN_IF_ERROR(AutoRetryLibUsbCommand( + [=] { + return libusb_claim_interface(libusb_handle_alias, interface_number); + }, + __func__)); + + claimed_interfaces_.insert(interface_number); + return util::Status(); // OK. +} + +util::Status LocalUsbDevice::ReleaseInterface(int interface_number) { + TRACE_SCOPE("LocalUsbDevice::ReleaseInterface"); + VLOG(10) << __func__; + + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + auto iterator = claimed_interfaces_.find(interface_number); + if (iterator != claimed_interfaces_.end()) { + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + RETURN_IF_ERROR(AutoRetryLibUsbCommand( + [=] { + return libusb_release_interface(libusb_handle_alias, + interface_number); + }, + __func__)); + + claimed_interfaces_.erase(iterator); + return util::Status(); // OK. + } + return util::NotFoundError(__func__); +} + +util::Status LocalUsbDevice::GetDescriptor(DescriptorType desc_type, + uint8_t desc_index, + MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + int result = 0; + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + RETURN_IF_ERROR(AutoRetryLibUsbCommand( + [=] { + // data_in is shallow copied into this lambda. + return libusb_get_descriptor( + libusb_handle_alias, static_cast(desc_type), desc_index, + data_in.data(), static_cast(data_in.size())); + }, + context, &result)); + + *num_bytes_transferred = static_cast(result); + return util::Status(); // OK. +} + +UsbDeviceInterface::DeviceSpeed LocalUsbDevice::GetDeviceSpeed() const { + StdMutexLock lock(&mutex_); + if (!CheckForNullHandle(__func__).ok()) { + return UsbDeviceInterface::DeviceSpeed::kUnknown; + } + + const int result = libusb_get_device_speed(libusb_get_device(libusb_handle_)); + + switch (result) { + case LIBUSB_SPEED_LOW: + return UsbDeviceInterface::DeviceSpeed::kLow; + case LIBUSB_SPEED_FULL: + return UsbDeviceInterface::DeviceSpeed::kFull; + case LIBUSB_SPEED_HIGH: + return UsbDeviceInterface::DeviceSpeed::kHigh; + case LIBUSB_SPEED_SUPER: + return UsbDeviceInterface::DeviceSpeed::kSuper; + case LIBUSB_SPEED_UNKNOWN: + default: + return UsbDeviceInterface::DeviceSpeed::kUnknown; + } +} + +util::Status LocalUsbDevice::SendControlCommand(const SetupPacket& command, + TimeoutMillis timeout_msec, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Length must be 0. + if (command.length != 0) { + return util::InvalidArgumentError("Length must be 0"); + } + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + + return AutoRetryLibUsbCommand( + [=] { + int result = libusb_control_transfer( + libusb_handle_alias, command.request_type, command.request, + command.value, command.index, nullptr, 0, timeout_msec); + + // Only 0 is the right answer here. + return (result > 0) ? 
LIBUSB_ERROR_OVERFLOW : result; + }, + __func__); +} + +util::Status LocalUsbDevice::SendControlCommandWithDataOut( + const SetupPacket& command, ConstBuffer data_out, + TimeoutMillis timeout_msec, const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Length must be less than or equal to buffer size. + CHECK_LE(command.length, data_out.length()); + + VLOG(10) << "SYNC CTRL WITH DATA OUT begin"; + + int result = 0; + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + + RETURN_IF_ERROR(AutoRetryLibUsbCommand( + [=] { + return libusb_control_transfer( + libusb_handle_alias, command.request_type, command.request, + command.value, command.index, const_cast(data_out.data()), + command.length, timeout_msec); + }, + context, &result)); + + VLOG(10) << "SYNC CTRL WITH DATA OUT end"; + + // Result must be less than or equal to specified data amount. + CHECK_LE(result, command.length); + + if (result != command.length) { + return util::DataLossError(__func__); + } + return util::Status(); // OK. +} + +util::Status LocalUsbDevice::SendControlCommandWithDataIn( + const SetupPacket& command, MutableBuffer data_in, + size_t* num_bytes_transferred, TimeoutMillis timeout_msec, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Length must be less than or equal to buffer size. + CHECK_LE(command.length, data_in.length()); + + VLOG(10) << "SYNC CTRL WITH DATA IN begin"; + + int result = 0; + + // This alias is created to circumvent thread safety analysis. + // Accessing to libusb_handle_ must be protected by a mutex. + auto* libusb_handle_alias = libusb_handle_; + + RETURN_IF_ERROR(AutoRetryLibUsbCommand( + [=] { + return libusb_control_transfer( + libusb_handle_alias, command.request_type, command.request, + command.value, command.index, data_in.data(), command.length, + timeout_msec); + }, + context, &result)); + + VLOG(10) << "SYNC CTRL WITH DATA IN end"; + + // Result must be less than or equal to specified data amount. + CHECK_LE(result, command.length); + + *num_bytes_transferred = static_cast(result); + return util::Status(); // OK. +} + +util::Status LocalUsbDevice::BulkOutTransfer(uint8_t endpoint, + ConstBuffer data_out, + TimeoutMillis timeout_msec, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + int amount_transferred = 0; + + VLOG(10) << StringPrintf("SYNC OUT %d begin", endpoint); + + const int result = libusb_bulk_transfer( + libusb_handle_, endpoint | LIBUSB_ENDPOINT_OUT, + const_cast(data_out.data()), data_out.length(), + &amount_transferred, timeout_msec); + + VLOG(10) << StringPrintf("SYNC OUT %d end", endpoint); + + if (result < 0) { + return ConvertLibUsbError(result, __func__); + } else { + // Underrun is a fatal error. + CHECK_LE(static_cast(amount_transferred), data_out.length()); + + if (static_cast(amount_transferred) != data_out.length()) { + return util::DataLossError(__func__); + } + } + return util::Status(); // OK. 
+} + +util::Status LocalUsbDevice::BulkInTransfer(uint8_t endpoint, + MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + int amount_transferred = 0; + *num_bytes_transferred = 0; + + VLOG(10) << StringPrintf("SYNC IN %d begin", endpoint); + + const int result = libusb_bulk_transfer( + libusb_handle_, endpoint | LIBUSB_ENDPOINT_IN, data_in.data(), + data_in.length(), &amount_transferred, timeout_msec); + + VLOG(10) << StringPrintf("SYNC IN %d end", endpoint); + + *num_bytes_transferred = static_cast(amount_transferred); + + if (result < 0) { + return ConvertLibUsbError(result, __func__); + } else { + // Overflow is a fatal error. + CHECK_LE(*num_bytes_transferred, data_in.length()); + } + + return util::Status(); // OK. +} + +util::Status LocalUsbDevice::InterruptInTransfer(uint8_t endpoint, + MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + int amount_transferred = 0; + *num_bytes_transferred = 0; + + VLOG(10) << StringPrintf("SYNC IN %d begin", endpoint); + + const int result = libusb_interrupt_transfer( + libusb_handle_, endpoint | LIBUSB_ENDPOINT_IN, data_in.data(), + data_in.length(), &amount_transferred, timeout_msec); + + VLOG(10) << StringPrintf("SYNC IN %d end", endpoint); + + *num_bytes_transferred = static_cast(amount_transferred); + + if (result < 0) { + return ConvertLibUsbError(result, __func__); + } else { + // Overflow is a fatal error. + CHECK_LE(*num_bytes_transferred, data_in.length()); + } + + return util::Status(); // OK. +} + +void LocalUsbDevice::UnregisterCompletedTransfer(libusb_transfer* transfer) { + VLOG(10) << __func__; + StdMutexLock lock(&async_callback_mutex_); + + // There must be exactly one element to be erased. + CHECK_EQ(async_transfers_.erase(transfer), 1); + + // Notify all, mostly just the main thread which is trying to close this + // device, that status has changed. + cond_.notify_all(); +} + +void LocalUsbDevice::LibUsbDataOutCallback(libusb_transfer* transfer) { + AsyncDataOutUserData* callback_obj = + static_cast(transfer->user_data); + + VLOG(10) << StringPrintf("ASYNC OUT %d end", transfer->endpoint); + + // The callback function is delivered without locking the host interface. + // This allows further calls to be made during callback. + (callback_obj->callback)( + ConvertLibUsbTransferStatus(transfer->status, __func__)); + + callback_obj->device->UnregisterCompletedTransfer(transfer); + delete callback_obj; +} + +void LocalUsbDevice::LibUsbDataInCallback(libusb_transfer* transfer) { + AsyncDataInUserData* callback_obj = + reinterpret_cast(transfer->user_data); + + VLOG(10) << StringPrintf("ASYNC IN %d end", transfer->endpoint & 0x7F); + + // The callback function is delivered without locking the host interface. + // This allows further calls to be made during callback. + (callback_obj->callback)( + ConvertLibUsbTransferStatus(transfer->status, __func__), + static_cast(transfer->actual_length)); + + callback_obj->device->UnregisterCompletedTransfer(transfer); + delete callback_obj; +} + +libusb_transfer* LocalUsbDevice::NewAsyncTransfer() { + // Allocate transfer control block. 
+ libusb_transfer* transfer_control = + libusb_alloc_transfer(kLibUsbTransferNoIsoPackets); + CHECK(transfer_control != nullptr); + + StdMutexLock lock(&async_callback_mutex_); + async_transfers_.insert(transfer_control); + + return transfer_control; +} + +void LocalUsbDevice::DestroyFailedAsyncTransfer( + libusb_transfer* transfer_control) { + StdMutexLock lock(&async_callback_mutex_); + async_transfers_.erase(transfer_control); + libusb_free_transfer(transfer_control); +} + +util::Status LocalUsbDevice::AsyncBulkOutTransfer(uint8_t endpoint, + ConstBuffer data_out, + TimeoutMillis timeout_msec, + DataOutDone callback, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Allocate transfer control block. + libusb_transfer* transfer_control = NewAsyncTransfer(); + + // Allocate user data for completion callback. + // This object is passed to transfer completion callback and will be released + // there. + AsyncDataOutUserData* callback_obj = + new AsyncDataOutUserData{this, std::move(callback)}; + CHECK(callback_obj != nullptr); + + VLOG(10) << StringPrintf("ASYNC OUT %d begin", endpoint); + + libusb_fill_bulk_transfer( + transfer_control, libusb_handle_, endpoint | LIBUSB_ENDPOINT_OUT, + const_cast(data_out.data()), data_out.length(), + LibUsbDataOutCallback, callback_obj, timeout_msec); + + transfer_control->flags |= + (LIBUSB_TRANSFER_SHORT_NOT_OK | LIBUSB_TRANSFER_FREE_TRANSFER); + + util::Status status = + ConvertLibUsbError(libusb_submit_transfer(transfer_control), __func__); + + if (!status.ok()) { + DestroyFailedAsyncTransfer(transfer_control); + delete callback_obj; + } + + return status; +} + +util::Status LocalUsbDevice::AsyncBulkInTransfer(uint8_t endpoint, + MutableBuffer data_in, + TimeoutMillis timeout_msec, + DataInDone callback, + const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Allocate transfer control block. + libusb_transfer* transfer_control = NewAsyncTransfer(); + + // Allocate user data for completion callback. + AsyncDataInUserData* callback_obj = + new AsyncDataInUserData{this, std::move(callback)}; + CHECK(callback_obj != nullptr); + + VLOG(10) << StringPrintf("ASYNC IN %d begin", endpoint & 0x7F); + + libusb_fill_bulk_transfer(transfer_control, libusb_handle_, + endpoint | LIBUSB_ENDPOINT_IN, data_in.data(), + data_in.length(), LibUsbDataInCallback, + callback_obj, timeout_msec); + + transfer_control->flags |= LIBUSB_TRANSFER_FREE_TRANSFER; + + util::Status status = + ConvertLibUsbError(libusb_submit_transfer(transfer_control), __func__); + + if (!status.ok()) { + DestroyFailedAsyncTransfer(transfer_control); + delete callback_obj; + } + return status; +} + +util::Status LocalUsbDevice::AsyncInterruptInTransfer( + uint8_t endpoint, MutableBuffer data_in, TimeoutMillis timeout_msec, + DataInDone callback, const char* context) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + // Allocate transfer control block. + libusb_transfer* transfer_control = NewAsyncTransfer(); + + // Allocate user data for completion callback. 
+ AsyncDataInUserData* callback_obj = + new AsyncDataInUserData{this, std::move(callback)}; + CHECK(callback_obj != nullptr); + + VLOG(10) << StringPrintf("ASYNC IN %d begin", endpoint & 0x7F); + + libusb_fill_interrupt_transfer(transfer_control, libusb_handle_, + endpoint | LIBUSB_ENDPOINT_IN, data_in.data(), + data_in.length(), LibUsbDataInCallback, + callback_obj, timeout_msec); + + transfer_control->flags |= LIBUSB_TRANSFER_FREE_TRANSFER; + + util::Status status = + ConvertLibUsbError(libusb_submit_transfer(transfer_control), __func__); + + if (!status.ok()) { + DestroyFailedAsyncTransfer(transfer_control); + delete callback_obj; + } + return status; +} + +util::StatusOr +LocalUsbDevice::AllocateTransferBuffer(size_t buffer_size) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + uint8_t* ptr = DoAllocateTransferBuffer(buffer_size); + if (ptr == nullptr) { + return util::ResourceExhaustedError(__func__); + } + + auto result = transfer_buffers_.insert( + std::make_pair(ptr, MutableBuffer(ptr, buffer_size))); + + return result.first->second; +} + +uint8_t* LocalUsbDevice::DoAllocateTransferBuffer(size_t buffer_size) { +#if LIBUSB_HAS_MEM_ALLOC + if (use_zero_copy_) { + // Release memory block through libusb and return from here. + return libusb_dev_mem_alloc(libusb_handle_, buffer_size); + } +#endif // LIBUSB_HAS_MEM_ALLOC + + return new uint8_t[buffer_size]; +} + +util::Status LocalUsbDevice::ReleaseTransferBuffer(MutableBuffer buffer) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + + RETURN_IF_ERROR(CheckForNullHandle(__func__)); + + auto block = transfer_buffers_.find(buffer.data()); + + // Missing record of the memory buffer is a fatal error. + CHECK(block != transfer_buffers_.end()); + + // Remove record before we actually release the memory. + // Iterator is invalidated after this line. + transfer_buffers_.erase(block); + + return DoReleaseTransferBuffer(buffer); +} + +util::Status LocalUsbDevice::DoReleaseTransferBuffer(MutableBuffer buffer) { +#if LIBUSB_HAS_MEM_ALLOC + if (use_zero_copy_) { + // Release memory block through libusb and return from here. + return ConvertLibUsbError( + libusb_dev_mem_free(libusb_handle_, buffer.data(), buffer.length()), + __func__); + } +#endif // LIBUSB_HAS_MEM_ALLOC + + // Use plain old delete [] to release the memory block. + delete[] buffer.data(); + + return util::Status(); // OK. +} + +LocalUsbDeviceFactory::LocalUsbDeviceFactory(bool use_zero_copy) + : use_zero_copy_(use_zero_copy) {} + +util::StatusOr +LocalUsbDeviceFactory::ParsePathString(const std::string& path) { + ParsedPath result; + unsigned int bus_number; + const size_t prefix_length = strlen(kUsbPathPrefix); + if (path.length() <= prefix_length) { + return util::InvalidArgumentError( + "Path must be longer than the proper prefix"); + } + + std::stringstream path_stingstream(path.substr(prefix_length)); + + path_stingstream >> bus_number; + if (path_stingstream.fail()) { + return util::InvalidArgumentError("Path must begin with bus number"); + } + if (path_stingstream.peek() == '-') { + path_stingstream.ignore(); + } else { + return util::InvalidArgumentError("Missing separator after bus number"); + } + // TODO: check for valid bus number range. + result.bus_number = static_cast(bus_number); + + unsigned int port; + while (path_stingstream >> port) { + if (path_stingstream.fail()) { + return util::InvalidArgumentError("Path must contain port numbers"); + } + + // TODO: check for valid port number range. 
+ result.port_numbers.push_back(static_cast(port)); + + if (path_stingstream.peek() == '.') { + path_stingstream.ignore(); + } + } + + return result; +} + +std::string LocalUsbDeviceFactory::ComposePathString( + const LocalUsbDeviceFactory::ParsedPath& path) { + std::stringstream result; + result << kUsbPathPrefix; + result << static_cast(path.bus_number); + bool is_first_port = true; + for (uint8 port_number : path.port_numbers) { + if (is_first_port) { + result << '-'; + is_first_port = false; + } else { + result << '.'; + } + result << static_cast(port_number); + } + return result.str(); +} + +util::StatusOr> +LocalUsbDeviceFactory::EnumerateDevices(uint16_t vendor_id, + uint16_t product_id) { + TRACE_SCOPE("LocalUsbDeviceFactory::EnumerateDevices"); + VLOG(6) << StringPrintf("%s: vendor:0x%x, product:0x%x", __func__, vendor_id, + product_id); + + libusb_context* context = nullptr; + const int libusb_init_error = libusb_init(&context); + if (libusb_init_error != 0) { + return util::FailedPreconditionError("libusb initialization failed"); + } + + util::Status libusb_option_status = + ConvertLibUsbError(SetLibUsbOptions(context), "SetLibUsbOptions"); + RETURN_IF_ERROR(libusb_option_status); + + auto context_cleaner = MakeCleanup([context] { libusb_exit(context); }); + + // Find the specified devices + libusb_device** device_list = nullptr; + ssize_t num_device_or_error = libusb_get_device_list(context, &device_list); + + if (num_device_or_error < 0) { + return ConvertLibUsbError(num_device_or_error, __func__); + } + auto device_list_cleaner = MakeCleanup([device_list] { + // The device list must be freed in the end. + // Remove one reference from all devices in the device list. + libusb_free_device_list(device_list, 1); + }); + + std::vector device_paths; + + for (ssize_t device_index = 0; device_index < num_device_or_error; + ++device_index) { + libusb_device* device = device_list[device_index]; + libusb_device_descriptor device_descriptor = {0}; + const uint8 bus_number = libusb_get_bus_number(device); + VLOG(7) << StringPrintf("%s: checking bus[%d] port[%d]", __func__, + bus_number, libusb_get_port_number(device)); + if (libusb_get_device_descriptor(device, &device_descriptor) == + LIBUSB_SUCCESS) { + if ((device_descriptor.idVendor == vendor_id) && + (device_descriptor.idProduct == product_id)) { + // Generate path string for this device. 
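
For concreteness (an editor's illustration, not part of the patch), the sysfs-style paths built here combine the prefix, the bus number, and the port chain; the snippet below assumes `ParsedPath` and `ComposePathString` are accessible to callers.

```cpp
// Example only: a device on bus 1 reached through ports {2, 3} is reported as
// "/sys/bus/usb/devices/1-2.3", and ParsePathString() accepts the same form.
LocalUsbDeviceFactory::ParsedPath example_path;
example_path.bus_number = 1;
example_path.port_numbers = {2, 3};
// ComposePathString(example_path) == "/sys/bus/usb/devices/1-2.3"
```
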
+ uint8 port_numbers[kMaxUsbPathDepth] = {0}; + const int depth_or_error = + libusb_get_port_numbers(device, port_numbers, kMaxUsbPathDepth); + if (depth_or_error < 0) { + VLOG(2) << StringPrintf("%s: get device port numbers failed:", + __func__) + << ConvertLibUsbError(depth_or_error, __func__); + } else { + ParsedPath parsed_path = { + bus_number, + std::vector(port_numbers, port_numbers + depth_or_error)}; + std::string path = ComposePathString(parsed_path); + VLOG(2) << StringPrintf("%s: found [%s]", __func__, path.c_str()); + device_paths.push_back(path); + } + } + } else { + VLOG(2) << StringPrintf("%s: get device descriptor failed", __func__); + } + } + + return device_paths; +} + +util::StatusOr> +LocalUsbDeviceFactory::OpenDevice(const std::string& path, TimeoutMillis) { + TRACE_SCOPE("LocalUsbDeviceFactory::OpenDevice"); + VLOG(6) << StringPrintf("%s: [%s]", __func__, path.c_str()); + + ASSIGN_OR_RETURN(auto parsed_path, ParsePathString(path)); + + libusb_context* context = nullptr; + const int libusb_init_error = libusb_init(&context); + if (libusb_init_error != 0) { + return util::FailedPreconditionError("libusb initialization failed"); + } + + util::Status libusb_option_status = + ConvertLibUsbError(SetLibUsbOptions(context), "SetLibUsbOptions"); + if (!libusb_option_status.ok()) { + return libusb_option_status; + } + + auto context_cleaner = MakeCleanup([context] { libusb_exit(context); }); + + // Find the specified devices + libusb_device** device_list = nullptr; + libusb_device* found_device = nullptr; + ssize_t num_device_or_error = libusb_get_device_list(context, &device_list); + + if (num_device_or_error < 0) { + return ConvertLibUsbError(num_device_or_error, __func__); + } + auto device_list_cleaner = MakeCleanup([device_list] { + // The device list must be freed in the end. + // Remove one reference from all devices in the device list. + libusb_free_device_list(device_list, 1); + }); + + for (ssize_t device_index = 0; device_index < num_device_or_error; + ++device_index) { + libusb_device* device = device_list[device_index]; + + const uint8 bus_number = libusb_get_bus_number(device); + VLOG(7) << StringPrintf("%s: checking bus[%d] port[%d]", __func__, + bus_number, libusb_get_port_number(device)); + + if (bus_number != parsed_path.bus_number) { + continue; + } + + // Generate path string for this device. + uint8 port_numbers[kMaxUsbPathDepth] = {0}; + const int depth_or_error = + libusb_get_port_numbers(device, port_numbers, kMaxUsbPathDepth); + if (depth_or_error < 0) { + VLOG(2) << StringPrintf("%s: get device port numbers failed:", __func__) + << ConvertLibUsbError(depth_or_error, __func__); + } else if (depth_or_error == parsed_path.port_numbers.size()) { + if (memcmp(port_numbers, parsed_path.port_numbers.data(), + parsed_path.port_numbers.size()) == 0) { + found_device = device; + break; + } + } + } + + libusb_device_handle* libusb_handle = nullptr; + if (found_device) { + RETURN_IF_ERROR(ConvertLibUsbError( + libusb_open(found_device, &libusb_handle), __func__)); + } else { + return util::NotFoundError(__func__); + } + + VLOG(6) << StringPrintf("%s: device opened %p", __func__, libusb_handle); + + std::unique_ptr device = gtl::WrapUnique( + new LocalUsbDevice(libusb_handle, use_zero_copy_, context)); + + CHECK(device); + + // Ownership of the libusb context has been transferred to the new device. + context_cleaner.release(); + + // This statement explicitly constructs an unique_ptr, instead of relying on + // implicit compiler-invoked conversion. 
Some C++ 11 compilers/verions do + // not properly invoke the most suitable conversion. + return {std::move(device)}; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/local_usb_device.h b/driver/usb/local_usb_device.h new file mode 100644 index 0000000..8638f5f --- /dev/null +++ b/driver/usb/local_usb_device.h @@ -0,0 +1,272 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_LOCAL_USB_DEVICE_H_ +#define DARWINN_DRIVER_USB_LOCAL_USB_DEVICE_H_ + +#include // NOLINT +#include // NOLINT +#include +#include // NOLINT +#include // NOLINT +#include + +#include "driver/usb/usb_device_interface.h" +#include "port/array_slice.h" +#include "port/defs.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" +#if DARWINN_PORT_USE_EXTERNAL +#include "libusb/libusb.h" +#else // !DARWINN_PORT_USE_EXTERNAL +#include +#endif // DARWINN_PORT_USE_EXTERNAL + +// TODO: upgrade this to allow the latest version of libusb to be used. +// We need at least 1.0.21 (LIBUSB_API_VERSION >= 0x01000105) for zero copy +// feature. + +namespace platforms { +namespace darwinn { +namespace driver { + +// Thread-safe implementation of UsbDeviceInterface on top of libusb. +// +// Async USB transfer functions are still async, but all other +// functions are serialized. +class LocalUsbDevice : public UsbDeviceInterface { + public: + // This class is neither copyable nor movable. + LocalUsbDevice(const LocalUsbDevice&) = delete; + LocalUsbDevice& operator=(const LocalUsbDevice&) = delete; + + // Destructor. Calls Close implicitly. 
+ ~LocalUsbDevice() override; + + util::Status Close(CloseAction action) override LOCKS_EXCLUDED(mutex_); + + util::Status SetConfiguration(int configuration) override + LOCKS_EXCLUDED(mutex_); + + util::Status ClaimInterface(int interface_number) override + LOCKS_EXCLUDED(mutex_); + + util::Status ReleaseInterface(int interface_number) override + LOCKS_EXCLUDED(mutex_); + + util::Status GetDescriptor(DescriptorType desc_type, uint8_t desc_index, + MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + DeviceSpeed GetDeviceSpeed() const override LOCKS_EXCLUDED(mutex_); + + util::Status SendControlCommand(const SetupPacket& command, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status SendControlCommandWithDataOut(const SetupPacket& command, + ConstBuffer data_out, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status SendControlCommandWithDataIn(const SetupPacket& command, + MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status BulkOutTransfer(uint8_t endpoint, ConstBuffer data_out, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status BulkInTransfer(uint8_t endpoint, MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status InterruptInTransfer(uint8_t endpoint, MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status AsyncBulkOutTransfer(uint8_t endpoint, ConstBuffer data_out, + TimeoutMillis timeout_msec, + DataOutDone callback, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status AsyncBulkInTransfer(uint8_t endpoint, MutableBuffer data_in, + TimeoutMillis timeout_msec, + DataInDone callback, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + util::Status AsyncInterruptInTransfer(uint8_t endpoint, MutableBuffer data_in, + TimeoutMillis timeout_msec, + DataInDone callback, + const char* context) override + LOCKS_EXCLUDED(mutex_); + + void TryCancelAllTransfers() override LOCKS_EXCLUDED(mutex_); + + util::StatusOr AllocateTransferBuffer( + size_t buffer_size) override LOCKS_EXCLUDED(mutex_); + + util::Status ReleaseTransferBuffer(MutableBuffer buffer) override + LOCKS_EXCLUDED(mutex_); + + private: + friend class LocalUsbDeviceFactory; + + // User data carried in libusb transfer completion callback. + struct AsyncDataOutUserData { + // Pointer to device object. + LocalUsbDevice* device; + // Pointer to callback function object. + DataOutDone callback; + }; + + // User data carried in libusb transfer completion callback. + struct AsyncDataInUserData { + // Pointer to device object. + LocalUsbDevice* device; + // Pointer to callback function object. + DataInDone callback; + }; + + // Constructor. All instances of this class must be allocated through + // LocalUsbManager. + LocalUsbDevice(libusb_device_handle* handle, bool use_zero_copy, + libusb_context* context); + + // Callback function provided to libubs for data out completion callback. + static void LibUsbDataOutCallback(libusb_transfer* transfer); + + // Callback function provided to libubs for data in completion callback. 
+ static void LibUsbDataInCallback(libusb_transfer* transfer); + + util::Status CheckForNullHandle(const char* context) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + void UnregisterCompletedTransfer(libusb_transfer* transfer) + LOCKS_EXCLUDED(mutex_); + + // Allocates transfer buffer for this device. + // Although this function doesn't explicitly modify any shared data, the + // underlying data in a libusb device handle could be affected. + uint8_t* DoAllocateTransferBuffer(size_t buffer_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Releases transfer buffer previous allocated. This function relies on caller + // to perform bookkeeping of records. Although this function doesn't + // explicitly modify any shared data, the underlying data in a libusb device + // handle could be affected. + util::Status DoReleaseTransferBuffer(MutableBuffer buffer) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + libusb_transfer* NewAsyncTransfer(); + void DestroyFailedAsyncTransfer(libusb_transfer* transfer_control); + + // Cancels all async transfers without explicitly locking the mutex. + void DoCancelAllTransfers() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + static constexpr int kLibUsbTransferNoIsoPackets = 0; + + // Serializes access to this interface and hence shared data. + mutable std::mutex mutex_; + + // Wait till all async transfers complete. + mutable std::condition_variable cond_; + + // True if the implementation should use libusb_dev_mem functions to allocate + // memory. + const bool use_zero_copy_; + + libusb_device_handle* libusb_handle_ GUARDED_BY(mutex_); + + // Interfaces, in current device configuration, have been claimed through + // ClaimInterface. + std::unordered_set claimed_interfaces_ GUARDED_BY(mutex_); + + // Memory buffers allocated through libusb_dev_mem_alloc. + std::map transfer_buffers_ GUARDED_BY(mutex_); + + // Serializes access to this interface and hence shared data. + mutable std::mutex async_callback_mutex_; + + // Transfer control blocks allocated for async USB transfers. + std::unordered_set async_transfers_ + GUARDED_BY(async_callback_mutex_); + + // Points to session context, which is allocated by libusb. + libusb_context* libusb_context_{nullptr}; + + // False if libusb event thread should stop running. + std::atomic libusb_keep_running_{false}; + + // Thread running the libusb event loop. + std::thread libusb_event_thread_; +}; + +class LocalUsbDeviceFactory : public UsbDeviceFactory { + public: + // Holds components to a path string, pointing to a locally connected USB + // device. + struct ParsedPath { + uint8 bus_number; + std::vector port_numbers; + }; + + LocalUsbDeviceFactory(bool use_zero_copy = false); + + ~LocalUsbDeviceFactory() override = default; + + // This class is neither copyable nor movable. + LocalUsbDeviceFactory(const LocalUsbDeviceFactory&) = delete; + LocalUsbDeviceFactory& operator=(const LocalUsbDeviceFactory&) = delete; + + util::StatusOr> EnumerateDevices( + uint16_t vendor_id, uint16_t product_id) override; + + util::StatusOr> OpenDevice( + const std::string& path, TimeoutMillis timeout_msec) override; + + // Visible for testing. + // Returns a path broken down to components. + static util::StatusOr ParsePathString(const std::string& path); + + // Visible for testing. + // Composes a path string from components. + static std::string ComposePathString(const ParsedPath& path); + + private: + // True if we should try to use memory allocation routine provided by libusb, + // for zero copy support. 
+ const bool use_zero_copy_{false}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_LOCAL_USB_DEVICE_H_ diff --git a/driver/usb/usb_device_interface.h b/driver/usb/usb_device_interface.h new file mode 100644 index 0000000..ba69e5c --- /dev/null +++ b/driver/usb/usb_device_interface.h @@ -0,0 +1,400 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_USB_DEVICE_INTERFACE_H_ +#define DARWINN_DRIVER_USB_USB_DEVICE_INTERFACE_H_ + +#include "port/array_slice.h" +#include "port/integral_types.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// This interface abstracts away access to a USB device. +// All operations listed here are primitive, and hence can be trust to be +// atomic/thread-safe in implementation. However, combinations of these +// primitive operations might have to be protected at higher level. +class UsbDeviceInterface { + public: + // Configuration number to be used in set configuration command. + enum ConfigurationNumber { + kFirstDeviceConfiguration = 1, + kResetDeviceConfiguration = -1, + }; + + // Device class, as defined in USB spec. + enum class DeviceClass { + // Use class information in the Interface descriptors. + kPerInterface = 0, + // Vendor-specific class. + kVendorSpecific = 0xff, + }; + + enum class DeviceSpeed { + // The implementation, probably a remote USB device, + // doesn't know the current USB speed. + kUnknown = 0, + + // Low speed is 1.5 Mbit/sec, defined in USB 1.0 + kLow = 1, + // Full speed is 11 Mbit/sec, defined in USB 1.0 + kFull = 2, + + // High speed is 480 Mbit/sec, defined in USB 2.0 + kHigh = 3, + + // Superspeed is 5 Gbit/sec, defined in USB 3.0. + kSuper = 4, + + // Superspeed+ is 10 Gbit/sec, defined in USB 3.1. + kSuperPlus = 5, + }; + + // Descriptor type, as defined in USB spec. + enum class DescriptorType { + // Device descriptor. + kDevice = 1, + // Configuration descriptor. + kConfig = 2, + // String descriptor. + kString = 3, + // Interface descriptor. + kInterface = 4, + // Endpoint descriptor. + kEndpoint = 5, + // Device qualifier. + kDeviceQualifier = 6, + // Other speed configuration. + kOtherSpeedConfiguration = 7, + // BOS descriptor. + kBos = 0xf, + // Device capability descriptor. + kDeviceCapability = 0x10, + // DFU functional descriptor. + kDfuFunctional = 0x21, + // Super speed endpoint companion descriptor. + kSuperSpeedEndpointCompanion = 0x30, + }; + + // Detailed definition of this enum can be found in USB spec. + // Used in specifying the request type in setup packet. + enum class CommandDataDir { + // Data, if present in this command, flows from host to device + kHostToDevice = 0, + // Data, if present in this command, flows from device to host + kDeviceToHost = 1, + }; + + // Detailed definition of this enum can be found in USB spec. + // Used in specifying the request type in setup packet. 
+ enum class CommandType { + // This is part of the standard command set every USB device has to support. + kStandard = 0, + // This is a class-specific command. + kClass = 1, + // This is a vendor-specific command. + kVendor = 2, + }; + + // Detailed definition of this enum can be found in USB spec. + // Used in specifying the request type in setup packet. + enum class CommandRecipient { + // The recipient of this command is the whole device. + kDevice = 0, + // The recipient of this command is the specified interface. + kInterface = 1, + // The recipient of this command is the specified endpoint. + kEndpoint = 2, + // The recipient of this command is none of the above. + kOther = 3, + }; + + // Setup packet is used in all commands sent over control endpoint 0. + // Detailed definition can be found in USB spec. Definition of most fields are + // command-specific, so the comments here only provide the general concept. + struct SetupPacket { + // Type, direction, and recipient of this command. + uint8_t request_type; + // The actual request ID. + uint8_t request; + // General field used to carry parameter in this command. + uint16_t value; + // Usually used to specify the interface number for this command. + uint16_t index; + // The amount of data, in number of bytes, for data phase associated with + // this command. + uint16_t length; + }; + + // Options available when closing the device. + enum class CloseAction { + // Closes the device. The same device can be opened right away. + kNoReset = 0, + + // Performs USB port reset before closing the device. USB bus re-enumeration + // could take some time, before the same, or a different device can be + // opened again. + kGracefulPortReset, + + // Perform emergency USB port reset without first releasing all interfaces, + // and then closes the device. Some resource leak or incompatibility + // with underlying OS could arise. + kForcefulPortReset, + + // Performs chip reset before closing the device. USB bus re-enumeration + // could take some time, before the same, or a different device can be + // opened again. Locally connected devices perform kPortReset instead. + kGracefulChipReset, + + // Perform emergency whole chip reset without first releasing all + // interfaces, and then closes the device. Some resource leak or + // incompatibility with underlying OS could arise. + kForcefulChipReset, + }; + + // Completion callback made when data in has been completed. + // Note that short data transfer is not considered an error, so application + // must check the amount received. + // This callback receives two arguments, the number of bytes transferred, and + // resulting status of the data in request. + using DataInDone = std::function; + + // Completion callback made when data out has been completed. + // Note that short data transfer is considered an error. + using DataOutDone = std::function; + + // Used to specify timeout, in number of milliseconds. + using TimeoutMillis = int; + + // Constant buffer, which is used to send data to device (USB OUT). + using ConstBuffer = gtl::ArraySlice; + + // Mutable buffer, which is used for receiving data from device (USB IN). + using MutableBuffer = gtl::MutableArraySlice; + + // Timeout, in milliseconds, to be used in USB operations. + enum TimeoutSpec : TimeoutMillis { + kDoNotRetry = 0, + kTimeoutOneSecond = 1000, + }; + + UsbDeviceInterface() = default; + + // This class is neither copyable nor movable. 
+ UsbDeviceInterface(const UsbDeviceInterface&) = delete; + UsbDeviceInterface& operator=(const UsbDeviceInterface&) = delete; + + virtual ~UsbDeviceInterface() = default; + + // Closes the device and releases all associated resources. + virtual util::Status Close(CloseAction action) = 0; + + // Sets the active configuration. Application must not assume a default + // configuration is already being set to active. Valid configuration starts + // with 1. Setting configuration to -1 would set the device into unconfigured + // state. All claimed interfaces must be released before one can change or + // reset configuration. + virtual util::Status SetConfiguration(int configuration) = 0; + + // Notifies underlying OS that this application intends to use this interface + // in current configuration. + virtual util::Status ClaimInterface(int interface_number) = 0; + + // Releases ownership of this interface in current configuration. + virtual util::Status ReleaseInterface(int interface_number) = 0; + + // Retrieves the specified descriptor from device. + virtual util::Status GetDescriptor(DescriptorType desc_type, + uint8_t desc_index, MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) = 0; + + virtual DeviceSpeed GetDeviceSpeed() const { return DeviceSpeed::kUnknown; } + + // Composes request type in setup packet for USB commands. + // This is an utility function for subclasses to compose setup packets. + static uint8_t ComposeUsbRequestType(CommandDataDir dir, CommandType type, + CommandRecipient recipient) { + constexpr int DataDirBitShift = 7; + constexpr int TypeBitShift = 5; + return (static_cast(dir) << DataDirBitShift) | + (static_cast(type) << TypeBitShift) | + static_cast(recipient); + } + + // Sets control command over endpoint 0, with no data phase + virtual util::Status SendControlCommand(const SetupPacket& command, + TimeoutMillis timeout_msec, + const char* context) = 0; + + // Sets control command over endpoint 0, with data out. + virtual util::Status SendControlCommandWithDataOut(const SetupPacket& command, + ConstBuffer data_out, + TimeoutMillis timeout_msec, + const char* context) = 0; + + // Sets control command over endpoint 0, with data in. + virtual util::Status SendControlCommandWithDataIn( + const SetupPacket& command, MutableBuffer data_in, + size_t* num_bytes_transferred, TimeoutMillis timeout_msec, + const char* context) = 0; + + // Transfers data on the specified bulk out endpoint. + // This function returns after the bulk out has been done. Short transfer + // is considered as an error. + virtual util::Status BulkOutTransfer(uint8_t endpoint, ConstBuffer data_out, + TimeoutMillis timeout_msec, + const char* context) = 0; + + // Transfers data on the specified bulk in endpoint. + // This function returns after the bulk in has been done. Short transfer + // is expected and the number of bytes transferred is returned through + // *num_bytes_transferred. + virtual util::Status BulkInTransfer(uint8_t endpoint, MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) = 0; + + // Transfers data on the specified interrupt in endpoint. + // This function returns after the interrupt in has been done. Short transfer + // is expected and the number of bytes transferred is returned through + // *num_bytes_transferred. 
+ virtual util::Status InterruptInTransfer(uint8_t endpoint, + MutableBuffer data_in, + size_t* num_bytes_transferred, + TimeoutMillis timeout_msec, + const char* context) = 0; + + // Transfers data on the specified bulk out endpoint. + // This function returns immediately after the data buffer is submitted into + // lower layer. A callback will be made, most probably from another thread, + // after the actual transfer is done. + virtual util::Status AsyncBulkOutTransfer(uint8_t endpoint, + ConstBuffer data_out, + TimeoutMillis timeout_msec, + DataOutDone callback, + const char* context) = 0; + + // Transfers data on the specified bulk in endpoint. + // This function returns immediately after the data buffer is submitted into + // lower layer. A callback will be made, most probably from another thread, + // after the actual transfer is done. + virtual util::Status AsyncBulkInTransfer(uint8_t endpoint, + MutableBuffer data_in, + TimeoutMillis timeout_msec, + DataInDone callback, + const char* context) = 0; + + // Transfers data on the specified interrupt in endpoint. + // This function returns immediately after the data buffer is submitted into + // lower layer. A callback will be made, most probably from another thread, + // after the actual transfer is done. + virtual util::Status AsyncInterruptInTransfer(uint8_t endpoint, + MutableBuffer data_in, + TimeoutMillis timeout_msec, + DataInDone callback, + const char* context) = 0; + + // Cancels all current transfers. This is a best-effort request. + virtual void TryCancelAllTransfers() = 0; + + // Allocates transfer buffer for subsequent data transfer. + // This is only useful in locally connected cases, and only if the underlying + // libusb and OS both support zero-copy on USB data transfer. + // If supported the amount of memory available could be limited by USB driver + // in kernel space. + // If not supported, the allocation would still be emulated in user space, + // and data might have to be copied between user and kernel space. + virtual util::StatusOr AllocateTransferBuffer( + size_t buffer_size) = 0; + + // Releases transfer buffer previously allocated. + // CloseDevice automatically releases all transfer buffers associated with the + // device. + virtual util::Status ReleaseTransferBuffer(MutableBuffer buffer) = 0; +}; + +// This interface abstracts the enumeration for connected USB devices. +// +// It is possible to connect to more than one devices at the same time, which +// could be interesting in some use cases. Extensions have to be +// made to the OpenDevice family of API to open multiple devices with the same +// vendor and product IDs. +class UsbManager { + public: + // Used to specify timeout, in number of milliseconds. + using TimeoutMillis = int; + + UsbManager() = default; + + // This class is neither copyable nor movable. + UsbManager(const UsbManager&) = delete; + UsbManager& operator=(const UsbManager&) = delete; + + virtual ~UsbManager() = default; + + // Opens a new device and returns an instance of UsbDeviceInterface for that + // discovered device. If there are multiple of them connected, only the first + // one is opened. Automatic retry could be made within timeout limit. + // timeout_msec could be 0, which means no retry is allowed. + // Negative timeout_msec, if supported, means unlimited/very long timeout and + // retry. + virtual util::StatusOr> OpenDevice( + uint16_t vendor_id, uint16_t product_id, TimeoutMillis timeout_msec) = 0; + + // Equivalent to OpenDevice() above, but without product_id specifier. 
The + // first device found with the specified vendor ID is opened, regardless of + // the product ID. Order of bus enumeration through this API is unreliable, + // and hence it's not guaranteed to open the same device everytime if more + // than one devices of the same vendor ID are present. + virtual util::StatusOr> OpenDevice( + uint16_t vendor_id, TimeoutMillis timeout_msec) = 0; + + static constexpr TimeoutMillis kDoNotRetry = 0; +}; + +// Factory class to produce path strings for devices connected to USB, +// and create device objects from the path strings. Thread-safe. +class UsbDeviceFactory { + public: + // Used to specify timeout, in number of milliseconds. + using TimeoutMillis = int; + + UsbDeviceFactory() = default; + + virtual ~UsbDeviceFactory() = default; + + // This class is neither copyable nor movable. + UsbDeviceFactory(const UsbDeviceFactory&) = delete; + UsbDeviceFactory& operator=(const UsbDeviceFactory&) = delete; + + // On success, returns a vector of strings for all connected USB devices + // matching the vendor and product ID specified. The strings are + // system-specific, but not limited to a particular factory instance. + virtual util::StatusOr> EnumerateDevices( + uint16_t vendor_id, uint16_t product_id) = 0; + + // Creates object implementing UsbDeviceInterface from the specified path + // string. The timeout is meant for the enumerating and opening operation. + virtual util::StatusOr> OpenDevice( + const std::string& path, TimeoutMillis timeout_msec) = 0; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_DEVICE_INTERFACE_H_ diff --git a/driver/usb/usb_dfu_commands.cc b/driver/usb/usb_dfu_commands.cc new file mode 100644 index 0000000..36a336d --- /dev/null +++ b/driver/usb/usb_dfu_commands.cc @@ -0,0 +1,539 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_dfu_commands.h" + +#include +#include + +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_standard_commands.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +UsbDfuCommands::UsbDfuCommands(std::unique_ptr device, + TimeoutMillis default_timeout_msec) + : UsbStandardCommands(std::move(device), default_timeout_msec) { + VLOG(10) << __func__; +} + +UsbDfuCommands::~UsbDfuCommands() { VLOG(10) << __func__; } + +util::Status UsbDfuCommands::DfuDetach(uint16_t timeout_msec) { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + SetupPacket command{ + // Request type (00100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuDetach), + + // Timeout in milliseconds. + timeout_msec, + + // Interface number. 
+ dfu_interface_number_, + + // Data length. + 0}; + + return SendControlCommand(command, __func__); +} + +void UsbDfuCommands::SetDfuInterface(int interface_number) { + StdMutexLock lock(&mutex_); + dfu_interface_number_ = static_cast(interface_number); + VLOG(5) << StringPrintf("%s set to %u", __func__, dfu_interface_number_); +} + +util::StatusOr, + UsbDfuCommands::DfuFunctionalDescriptor>> +UsbDfuCommands::FindDfuInterfaces( + const std::vector& raw_configuration_descriptor) { + constexpr size_t kConfigDescriptorRawByteSize = 9; + constexpr size_t kInterfaceDescriptorRawByteSize = 9; + if (raw_configuration_descriptor.size() < kConfigDescriptorRawByteSize) { + return util::InvalidArgumentError("Raw data is way too short"); + } + + const uint8_t reported_config_type = raw_configuration_descriptor[1]; + if (reported_config_type != + static_cast(UsbDeviceInterface::DescriptorType::kConfig)) { + return util::InvalidArgumentError("Not reported as config descriptor"); + } + + const uint8_t reported_config_length = raw_configuration_descriptor[0]; + const uint8_t reported_total_data_length = raw_configuration_descriptor[2]; + if (reported_total_data_length > raw_configuration_descriptor.size()) { + return util::InvalidArgumentError("Incomplete config descriptor"); + } + // Every configuration must has at least one interface. + if (reported_total_data_length < + kConfigDescriptorRawByteSize + kInterfaceDescriptorRawByteSize) { + return util::InvalidArgumentError("Reported total data is way too short"); + } + + bool found_dfu_functional_descriptor = false; + std::list dfu_interfaces; + DfuFunctionalDescriptor dfu_functional_descriptor; + size_t cursor = reported_config_length; + do { + VLOG(10) << StringPrintf("%s cursor %u", __func__, + static_cast(cursor)); + if ((cursor + 1) >= raw_configuration_descriptor.size()) break; + const uint8_t length = raw_configuration_descriptor[cursor]; + const uint8_t type = raw_configuration_descriptor[cursor + 1]; + VLOG(10) << StringPrintf("%s type 0x%x, length %u", __func__, type, length); + + if (length == 0) { + return util::FailedPreconditionError( + "Length of functional descriptor must not be 0"); + } + + if (type == + static_cast(UsbDeviceInterface::DescriptorType::kInterface)) { + // Treat this as an interface descriptor. + // Make sure we have valid access to the whole descriptor. + if ((cursor + kInterfaceDescriptorRawByteSize - 1) >= + raw_configuration_descriptor.size()) { + break; + } + + // TODO consider changing numbers to named constants. 
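+      // For reference, the offsets used below follow the standard interface
+      // descriptor layout (USB 2.0 spec, Table 9-12): [0] bLength,
+      // [1] bDescriptorType, [2] bInterfaceNumber, [3] bAlternateSetting,
+      // [4] bNumEndpoints, [5] bInterfaceClass, [6] bInterfaceSubClass,
+      // [7] bInterfaceProtocol, [8] iInterface.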
+ InterfaceDescriptor interface; + interface.interface_number = raw_configuration_descriptor[cursor + 2]; + interface.alternate_setting = raw_configuration_descriptor[cursor + 3]; + interface.num_endpoints = raw_configuration_descriptor[cursor + 4]; + interface.interface_class = raw_configuration_descriptor[cursor + 5]; + interface.interface_subclass = raw_configuration_descriptor[cursor + 6]; + interface.interface_protocol = raw_configuration_descriptor[cursor + 7]; + interface.interface_name_index = raw_configuration_descriptor[cursor + 8]; + + VLOG(10) << StringPrintf( + "%s interface %d, alternate settings %u, num of extra endpoints %u, " + "class 0x%x, subclass 0x%x", + __func__, interface.interface_number, interface.alternate_setting, + interface.num_endpoints, interface.interface_class, + interface.interface_subclass); + + if ((interface.num_endpoints == 0) && + (interface.interface_class == + /*Application specific, see DFU spec 1.1*/ 0xFE) && + (interface.interface_subclass == /*DFU, see DFU spec 1.1*/ 1)) { + dfu_interfaces.push_back(interface); + } + + } else if (type == + static_cast( + UsbDeviceInterface::DescriptorType::kDfuFunctional)) { + // Make sure we have valid access to the whole descriptor. + constexpr size_t kDfuFunctionalDescriptorRawByteSize = 9; + if ((cursor + kDfuFunctionalDescriptorRawByteSize - 1) >= + raw_configuration_descriptor.size()) + break; + + found_dfu_functional_descriptor = true; + + // TODO consider adding named constants for numbers used in + // parsing. + // Fill fields from raw bytes to more accessible data structure. + // memcpy is used for all multi-byte data fields to avoid any potential + // alignment issues. + const uint8_t attributes = raw_configuration_descriptor[cursor + 2]; + dfu_functional_descriptor.will_detach = + static_cast(attributes & 0x8); + dfu_functional_descriptor.manifestation_tolerant = + static_cast(attributes & 0x4); + dfu_functional_descriptor.can_upload = + static_cast(attributes & 0x2); + dfu_functional_descriptor.can_download = + static_cast(attributes & 0x1); + memcpy(&dfu_functional_descriptor.detach_timeout_msec, + &raw_configuration_descriptor[cursor + 3], 2); + memcpy(&dfu_functional_descriptor.transfer_size, + &raw_configuration_descriptor[cursor + 5], 2); + memcpy(&dfu_functional_descriptor.dfu_version_bcd, + &raw_configuration_descriptor[cursor + 7], 2); + + VLOG(7) << StringPrintf("Will detach: %d, manifestation tolerant: %d", + dfu_functional_descriptor.will_detach, + dfu_functional_descriptor.manifestation_tolerant); + VLOG(7) << StringPrintf("Can upload: %d, can download: %d", + dfu_functional_descriptor.can_upload, + dfu_functional_descriptor.can_download); + VLOG(7) << StringPrintf("Transfer Size: 0x%x", + dfu_functional_descriptor.transfer_size); + VLOG(7) << StringPrintf("Detach Timeout: 0x%x", + dfu_functional_descriptor.detach_timeout_msec); + VLOG(7) << StringPrintf("DFU version in BCD: 0x%x", + dfu_functional_descriptor.dfu_version_bcd); + } else { + // Skip unrecognized entries. + } + cursor += length; + } while (true); + + if ((!dfu_interfaces.empty()) && found_dfu_functional_descriptor) { + return std::make_pair(std::move(dfu_interfaces), dfu_functional_descriptor); + } + + return util::NotFoundError(__func__); +} + +util::StatusOr UsbDfuCommands::DfuGetStatus() { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + constexpr size_t kGetStatusRawByteSize = 6; + uint8_t buffer[kGetStatusRawByteSize] = {0}; + SetupPacket command{ + // Request type (10100001b). 
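+      // (10100001b == 0xA1: bit 7 set for device-to-host, bits 6:5 == 01 for
+      // a class request, and bits 4:0 == 00001 for an interface recipient,
+      // which is exactly what ComposeUsbRequestType() assembles below.)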
+ ComposeUsbRequestType(CommandDataDir::kDeviceToHost, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuGetStatus), + + // Value is not used. + 0, + + // Interface number. + dfu_interface_number_, + + // Data length. + sizeof(buffer)}; + + size_t num_bytes_transferred = 0; + RETURN_IF_ERROR(SendControlCommandWithDataIn( + command, MutableBuffer(buffer, sizeof(buffer)), &num_bytes_transferred, + __func__)); + + if (num_bytes_transferred != sizeof(buffer)) { + return util::UnknownError("Invalid DFU status data"); + } + + DfuStatus dfu_status; + + // Initialize all fields to 0. The common case {0} doesn't work here, as + // scoped enum doesn't have implicit conversion. + memset(&dfu_status, 0, sizeof(dfu_status)); + + // Fill fields from raw bytes to more accessible data structure. + // memcpy is used for all multi-byte data fields to avoid any potential + // alignment issues. + dfu_status.previous_result = static_cast(buffer[0]); + + // TODO consider adding named constants for numbers used. + // Note this field is only 3-byte-long. Since we assume everything is little + // endian, only the first 3 bytes on the target are overwritten. + memcpy(&dfu_status.poll_timeout_msec, &buffer[1], 3); + + dfu_status.state = static_cast(buffer[4]); + dfu_status.status_string_index = buffer[5]; + + VLOG(7) << StringPrintf("Previous result: %d", + static_cast(dfu_status.previous_result)); + VLOG(7) << StringPrintf("Poll timeout: %d", dfu_status.poll_timeout_msec); + VLOG(7) << StringPrintf("State: %d", static_cast(dfu_status.state)); + VLOG(7) << StringPrintf("Status string index: %d", + dfu_status.status_string_index); + + return dfu_status; +} + +util::Status UsbDfuCommands::DfuClearStatus() { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + SetupPacket command{ + // Request type (00100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuClearStatus), + + // Value is not used. + 0, + + // Interface number. + dfu_interface_number_, + + // Data length. + 0}; + + return SendControlCommand(command, __func__); +} + +util::Status UsbDfuCommands::DfuAbort() { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + SetupPacket command{ + // Request type (00100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuAbort), + + // Value is not used. + 0, + + // Interface number. + dfu_interface_number_, + + // Data length. + 0}; + + return SendControlCommand(command, __func__); +} + +util::StatusOr UsbDfuCommands::DfuGetState() { + VLOG(10) << __func__; + StdMutexLock lock(&mutex_); + constexpr size_t kGetStateRawByteSize = 1; + uint8_t buffer[kGetStateRawByteSize] = {0}; + SetupPacket command{ + // Request type (10100001b). + ComposeUsbRequestType(CommandDataDir::kDeviceToHost, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuGetState), + + // Value is not used. + 0, + + // Interface number. + dfu_interface_number_, + + // Data length. 
+ sizeof(buffer)}; + + size_t num_bytes_transferred = 0; + RETURN_IF_ERROR(SendControlCommandWithDataIn( + command, MutableBuffer(buffer, sizeof(buffer)), &num_bytes_transferred, + __func__)); + + if (num_bytes_transferred != sizeof(buffer)) { + return util::UnknownError("Invalid DFU state data"); + } + + State dfu_state = static_cast(buffer[0]); + + VLOG(7) << StringPrintf("State: %d", static_cast(dfu_state)); + + return dfu_state; +} + +util::Status UsbDfuCommands::DfuDownloadBlock(uint16_t block_number, + ConstBuffer block_buffer) { + VLOG(10) << StringPrintf("%s block %u, request size %u", __func__, + block_number, + static_cast(block_buffer.size())); + + StdMutexLock lock(&mutex_); + + SetupPacket command{ + // Request type (10100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuDownload), + + // Block number. + block_number, + + // Interface number. + dfu_interface_number_, + + // Data length. + static_cast(block_buffer.size())}; + + return SendControlCommandWithDataOut(command, block_buffer, __func__); +} + +util::Status UsbDfuCommands::DfuUploadBlock(uint16_t block_number, + MutableBuffer block_buffer, + size_t* num_bytes_transferred) { + VLOG(10) << StringPrintf("%s block %u, request size %u", __func__, + block_number, + static_cast(block_buffer.size())); + + StdMutexLock lock(&mutex_); + + SetupPacket command{ + // Request type (10100001b). + ComposeUsbRequestType(CommandDataDir::kDeviceToHost, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id. + static_cast(RequestId::kDfuUpload), + + // Block number. + block_number, + + // Interface number. + dfu_interface_number_, + + // Data length. + static_cast(block_buffer.size())}; + + // command.length = static_cast(block_buffer.size()); + + return SendControlCommandWithDataIn(command, block_buffer, + num_bytes_transferred, __func__); +} + +util::Status UsbDfuCommands::UpdateFirmware( + const DfuFunctionalDescriptor& descriptor, ConstBuffer firmware_image) { + VLOG(7) << StringPrintf("%s Downloading firmware", __func__); + + if (firmware_image.empty()) { + return util::InvalidArgumentError("Invalid DFU image file"); + } + + VLOG(7) << StringPrintf("%s Firmware image size %zu bytes", __func__, + firmware_image.size()); + + // TODO: try DFU abort or clear status to clear the stage, if we're + // not in DFU idle state. 
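+  // The loop below follows the DFU 1.1 download sequence in its simplest
+  // form: each DFU_DNLOAD request carries at most descriptor.transfer_size
+  // bytes and is followed by a DFU_GETSTATUS poll; the device reports
+  // dfuDNLOAD-IDLE while it expects more data, and a final zero-length
+  // DFU_DNLOAD whose status comes back OK in dfuIDLE marks the download as
+  // complete.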
+ + uint16 block_number = 0; + size_t num_bytes_transferred = 0; + bool last_packet_sent = false; + while (num_bytes_transferred <= firmware_image.size()) { + const uint16 transfer_size = static_cast( + std::min(static_cast(descriptor.transfer_size), + firmware_image.size() - num_bytes_transferred)); + + if (transfer_size == 0) { + VLOG(8) << StringPrintf("%s Sending the final zero-length packet", + __func__); + } else { + VLOG(8) << StringPrintf( + "%s Transfer size %u bytes, already transferred %zu bytes", __func__, + transfer_size, num_bytes_transferred); + } + + RETURN_IF_ERROR(DfuDownloadBlock( + block_number, + ConstBuffer(firmware_image, num_bytes_transferred, transfer_size))); + + auto dfu_status_query_result = DfuGetStatus(); + RETURN_IF_ERROR(dfu_status_query_result.status()); + auto dfu_status = dfu_status_query_result.ValueOrDie(); + + VLOG(8) << StringPrintf("%s: block %d status:%d, state:%d", __func__, + block_number, + static_cast(dfu_status.previous_result), + static_cast(dfu_status.state)); + + if ((Error::kOK == dfu_status.previous_result) && + (State::kDownloadIdle == dfu_status.state)) { + // keep track of accumulated data + num_bytes_transferred += transfer_size; + } else if ((0 == transfer_size) && + (Error::kOK == dfu_status.previous_result) && + (State::kDfuIdle == dfu_status.state)) { + // The last packet sent has zero length. Downloading is done. + last_packet_sent = true; + break; + } else { + VLOG(8) << StringPrintf("%s: download failed", __func__); + break; + } + // block number could wrap around + ++block_number; + } + + VLOG(7) << StringPrintf("%s, transferred image size: %zu, EOF: %d", __func__, + num_bytes_transferred, + (firmware_image.size() == num_bytes_transferred)); + + if (last_packet_sent) { + return util::Status(); // OK. + } + + return util::DataLossError("Firmware downloading failed"); +} + +util::Status UsbDfuCommands::ValidateFirmware( + const DfuFunctionalDescriptor& descriptor, ConstBuffer firmware_image) { + VLOG(7) << StringPrintf("%s Validating firmware", __func__); + + uint16_t block_number = 0; + bool short_packet_received = false; + std::vector upload_image; + upload_image.reserve(firmware_image.size()); + std::vector chunk_buffer(descriptor.transfer_size); + while (true) { + // Always asks for max transfer size. + const uint16 transfer_size = static_cast(descriptor.transfer_size); + + VLOG(10) << StringPrintf("%s Reading firmware block %d", __func__, + block_number); + + size_t chunk_bytes_transferred = 0; + RETURN_IF_ERROR(DfuUploadBlock( + block_number, MutableBuffer(chunk_buffer.data(), chunk_buffer.size()), + &chunk_bytes_transferred)); + + upload_image.insert(upload_image.end(), chunk_buffer.begin(), + chunk_buffer.begin() + chunk_bytes_transferred); + + if (chunk_bytes_transferred < transfer_size) { + // A short packet! Upload is done. + short_packet_received = true; + break; + } + // block number could wrap around. + ++block_number; + } + + VLOG(7) << StringPrintf("%s, Uploaded image size: %zu", __func__, + upload_image.size()); + + if (upload_image.size() < firmware_image.size()) { + VLOG(1) << StringPrintf("%s, Uploaded image is shorter than expected", + __func__); + return util::DataLossError(__func__); + } + + // Only compares the first part of uploaded image, for it's possible for + // the uploaded images to be longer than the reference image. + if (0 == memcmp(upload_image.data(), firmware_image.data(), + firmware_image.size())) { + return util::Status(); // OK. 
+ } + + VLOG(1) << StringPrintf("%s, Uploaded image is different from expected", + __func__); + + return util::DataLossError(__func__); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_dfu_commands.h b/driver/usb/usb_dfu_commands.h new file mode 100644 index 0000000..49d6d4a --- /dev/null +++ b/driver/usb/usb_dfu_commands.h @@ -0,0 +1,257 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_USB_DFU_COMMANDS_H_ +#define DARWINN_DRIVER_USB_USB_DFU_COMMANDS_H_ + +#include +#include // NOLINT + +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_standard_commands.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Thread-safe implementation of USB Device Firmare Update protocol. +// Note thread-safety doesn't mean much here, as the device cannot respond +// to interferences properly during DFU process. +// TODO provide a mechanism like locked/unlocked versions of functions +// and a busy state to prevent any interruption in the middle of a long sequence +// like firmware update. +class UsbDfuCommands : public UsbStandardCommands { + public: + // Detailed definition of these request IDs can be found in DFU spec. + enum class RequestId { + // Pushes a device into app detached state. + kDfuDetach = 0, + + // Sends one chunk of firmware to device. + kDfuDownload = 1, + + // Retrieves one chunk of firmware from device. + kDfuUpload = 2, + + // Retrieves DFU status from device. + kDfuGetStatus = 3, + + // Clears error status in DFU mode. + kDfuClearStatus = 4, + + // Retrieves DFU state without affecting the state. + kDfuGetState = 5, + + // Aborts current DFU operation. + kDfuAbort = 6, + }; + + // Detail definition of these states can be found in DFU spec v1.1. + enum class State { + // Normal/idle in application mode. + kAppIdle = 0, + + // Detached in application mode, waiting for USB reset to enter DFU mode. + kAppDetach = 1, + + // Normal/idle in DFU mode. + kDfuIdle = 2, + + // Downloading in DFU mode, waiting for GetStatus. + kDownloadSync = 3, + + // Downloading in DFU mode, blocking further GetStatus. + kDownloadBusy = 4, + + // Downloading in DFU mode, waiting for the next packet. + kDownloadIdle = 5, + + // Programming in DFU mode, waiting for the last GetStatus to begin + // manifest phase. + kManifestSync = 6, + + // Programming in DFU mode. + kManifest = 7, + + // Programming in DFU mode, waiting for the USB reset to leave DFU mode. + kManifestWaitReset = 8, + + // Uploading in DFU mode, waiting for the next DfuUpload. + kUploadIdle = 9, + + // Error state in DFU mode, waiting for ClearStatus. + kError = 10, + }; + + // Detail definition of these error codes can be found in DFU spec v1.1. + enum class Error { + // No error. + kOK = 0, + + // File is not targeted for this device. + kWrongTarget = 1, + + // Vendor-specific verification failed. 
+ kFileVerifyFailed = 2, + + // Write memory failed. + kWriteFailed = 3, + + // Erase memory failed. + kEraseFailed = 4, + + // Check failed for erasing memory. + kEraseCheckFailed = 5, + + // Program memory failed. + kProgramFailed = 6, + + // Check failed for programming memory. + kProgramVerifyFailed = 7, + + // Program failed because of address is invalid. + kInvalidAddress = 8, + + // The downloaded firmware image doesn't seem to be long enough. + kInsufficientData = 9, + + // The firmware is corrupted. + kFirmwareIsCorrupt = 10, + + // Vendor-specific error. + kVendorSpecificError = 11, + + // Unexpected USB reset detected. + kUnexpectedUsbResetDetected = 12, + + // Unexpected POR detected. + kUnexpectedPowerOnResetDetected = 13, + + // Unknown error. + kUnknownError = 14, + + // Unexpected request. + kUnexpectedRequestStalled = 15, + }; + + // Current device status returned by DFU GetStatus command. + // Detailed definition of these fields can be found in DFU spec v1.1. + struct DfuStatus { + // Status before executing this GetStatus command. + Error previous_result; + + // Minimum time the host should wait before a subsequent GetStatus command. + // Valid range is 0-0xFFFFFF only + uint32_t poll_timeout_msec; + + // State after executing this GetStatus command. + State state; + + // Index of status description in the string table. + uint8_t status_string_index; + }; + + // Functional descriptor for DFU funtction. + // Detailed definition of these fields can be found in DFU spec v1.1. + struct DfuFunctionalDescriptor { + // True if host must not send USB reset after DFU_DETACH command. + bool will_detach; + + // True if device goes back to DFU Idle after manifestation. + bool manifestation_tolerant; + + // True if device can upload firmware image to host. + bool can_upload; + + // True if device can download firmware image from host. + bool can_download; + + // Max time, in msec, before device returns to App Idle mode from App + // Detach. + uint16_t detach_timeout_msec; + + // Max number of bytes in each control read/write request. + // This number should be larger than max packet size for ep 0. + uint16_t transfer_size; + + // DFU version supported in BCD. This field must be at least 0100h. + uint16_t dfu_version_bcd; + }; + + // Constructs a new object from pointer to an USB device. + UsbDfuCommands(std::unique_ptr device, + TimeoutMillis default_timeout_msec); + + // This class is neither copyable but movable. + UsbDfuCommands(const UsbDfuCommands&) = delete; + UsbDfuCommands& operator=(const UsbDfuCommands&) = delete; + + ~UsbDfuCommands() override; + + // Gets DFU functional descriptor from device. + util::StatusOr< + std::pair, DfuFunctionalDescriptor>> + FindDfuInterfaces(const std::vector& raw_configuration_descriptor); + + // Sets the target interface number for DFU interface-specific commands, + // including DfuGetStatus, DfuClearStatus, DfuAbort,DfuGetState, + // DfuDownloadBlock, and DfuUploadBlock. + // Only the low 16-bit is used, as per USB spec. + void SetDfuInterface(int interface_number) LOCKS_EXCLUDED(mutex_); + + // Detaches from application mode. + util::Status DfuDetach(uint16_t timeout_msec); + + // Retrieves DFU status from device. + util::StatusOr DfuGetStatus() LOCKS_EXCLUDED(mutex_); + + // Clears error status in DFU mode. + util::Status DfuClearStatus() LOCKS_EXCLUDED(mutex_); + + // Aborts current DFU operation. + util::Status DfuAbort() LOCKS_EXCLUDED(mutex_); + + // Retrieves DFU state from device without affecting the virtual state. 
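+  // (Per DFU 1.1, DFU_GETSTATE only reads the current state; unlike
+  // DFU_GETSTATUS it never triggers a state transition on the device.)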
+ util::StatusOr DfuGetState() LOCKS_EXCLUDED(mutex_); + + // Downloads a block of firmware from host to device. + util::Status DfuDownloadBlock(uint16_t block_number, ConstBuffer block_buffer) + LOCKS_EXCLUDED(mutex_); + + // Uploads a block of firmware frmo device to host. + util::Status DfuUploadBlock(uint16_t block_number, MutableBuffer block_buffer, + size_t* num_bytes_transferred) + LOCKS_EXCLUDED(mutex_); + + util::Status UpdateFirmware(const DfuFunctionalDescriptor& descriptor, + ConstBuffer firmware_image) + LOCKS_EXCLUDED(mutex_); + + util::Status ValidateFirmware(const DfuFunctionalDescriptor& descriptor, + ConstBuffer firmware_image) + LOCKS_EXCLUDED(mutex_); + + private: + // Serializes access to this interface and hence shared data. + mutable std::mutex mutex_; + + uint16_t dfu_interface_number_ GUARDED_BY(mutex_){0}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_DFU_COMMANDS_H_ diff --git a/driver/usb/usb_dfu_util.cc b/driver/usb/usb_dfu_util.cc new file mode 100644 index 0000000..6fd8eb2 --- /dev/null +++ b/driver/usb/usb_dfu_util.cc @@ -0,0 +1,146 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_dfu_util.h" + +#include +#include + +#include "port/errors.h" +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" +#include "port/time.h" +#include "port/tracing.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// TODO: Consider an absl::Duration instead +// Sleep time after a reset for DFU. +// TODO: revisit this setting after we finalize PHY tuning. +constexpr int kSleepTimeSecondsAfterReset = 4; + +// TODO add proper error handling to this function. +// Convenience function to read a file to a vector. +std::vector ReadToVector(const std::string& file_name) { + VLOG(10) << __func__ << file_name; + + // TODO directly read into the vector instead of transcopying through + // a string. 
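+  // A minimal sketch of the direct approach suggested above (untested;
+  // assumes the return type holds uint8_t-compatible bytes and that opening
+  // the firmware image in binary mode is acceptable here):
+  //
+  //   std::ifstream ifs(file_name, std::ios::binary);
+  //   return std::vector<uint8_t>((std::istreambuf_iterator<char>(ifs)),
+  //                               std::istreambuf_iterator<char>());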
+ std::ifstream ifs(file_name); + std::string content_string((std::istreambuf_iterator(ifs)), + (std::istreambuf_iterator())); + + std::vector result; + auto data = reinterpret_cast(content_string.c_str()); + result.insert(result.end(), data, data + content_string.size()); + return result; +} + +} // namespace + +util::Status UsbUpdateDfuDevice(UsbDfuCommands* dfu_device, + UsbDeviceInterface::ConstBuffer firmware_image, + bool skip_verify) { + TRACE_SCOPE("UsbUpdateDfuDevice"); + + VLOG(10) << StringPrintf("%s Loading descriptor for the first configuration", + __func__); + + constexpr size_t kMaxConfigDescriptorAllowed = 512; + + ASSIGN_OR_RETURN(UsbDfuCommands::ConfigurationDescriptor config_descriptor, + dfu_device->GetConfigurationDescriptor( + UsbDeviceInterface::kFirstDeviceConfiguration, + kMaxConfigDescriptorAllowed)); + + ASSIGN_OR_RETURN(auto dfu_interfaces, + dfu_device->FindDfuInterfaces(config_descriptor.raw_data)); + + const int dfu_interface = dfu_interfaces.first.begin()->interface_number; + + VLOG(10) << StringPrintf( + "%s Num of DFU interfaces %zu, claiming interface %d", __func__, + dfu_interfaces.first.size(), dfu_interface); + + RETURN_IF_ERROR(dfu_device->ClaimInterface(dfu_interface)); + + dfu_device->SetDfuInterface(dfu_interface); + + RETURN_IF_ERROR( + dfu_device->UpdateFirmware(dfu_interfaces.second, firmware_image)); + + if (skip_verify) { + return util::Status(); // OK. + } else { + return dfu_device->ValidateFirmware(dfu_interfaces.second, firmware_image); + } +} + +util::Status UsbUpdateAllDfuDevices(UsbManager* usb_manager, uint16_t vendor_id, + uint16_t product_id, + const std::string& firmware_filename, + bool skip_verify) { + VLOG(7) << StringPrintf("%s Downloading firmware file:%s", __func__, + firmware_filename.c_str()); + + auto firmware_image = ReadToVector(firmware_filename); + if (firmware_image.empty()) { + return util::InvalidArgumentError("Invalid DFU image file"); + } + + // Perform DFU on all devices that have the right product ID. + bool is_dfu_attempted = false; + constexpr int kMaxNumDfuRun = 10; + for (int dfu_count = 0; dfu_count < kMaxNumDfuRun; ++dfu_count) { + auto dfu_target = + usb_manager->OpenDevice(vendor_id, product_id, UsbManager::kDoNotRetry); + if (!dfu_target.ok()) { + // Leave this step if we couldn't find any device for DFU. + VLOG(7) << StringPrintf("%s No more device is in need for DFU", __func__); + break; + } + + is_dfu_attempted = true; + VLOG(7) << StringPrintf("%s Performing DFU on device %d", __func__, + dfu_count); + + UsbDfuCommands dfu_commands(std::move(dfu_target.ValueOrDie()), + UsbDeviceInterface::kTimeoutOneSecond); + // Return error if we encounter any error in the DFU process. + // Returning here avoid trying DFU on the same faulty + // device indefinitely. + RETURN_IF_ERROR( + UsbUpdateDfuDevice(&dfu_commands, firmware_image, skip_verify)); + RETURN_IF_ERROR( + dfu_commands.Close(UsbDfuCommands::CloseAction::kGracefulPortReset)); + } + + if (is_dfu_attempted) { + // Wait for short period of time so the devices could come back after reset. + VLOG(7) << StringPrintf( + "%s DFU completed. Waiting for devices to come back", __func__); + Sleep(kSleepTimeSecondsAfterReset); + } + + return util::Status(); // OK. 
+} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_dfu_util.h b/driver/usb/usb_dfu_util.h new file mode 100644 index 0000000..fc48deb --- /dev/null +++ b/driver/usb/usb_dfu_util.h @@ -0,0 +1,43 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_USB_DFU_UTIL_H_ +#define DARWINN_DRIVER_USB_USB_DFU_UTIL_H_ + +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_dfu_commands.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Performs DFU on device with specified firmware image. +util::Status UsbUpdateDfuDevice(UsbDfuCommands* dfu_device, + UsbDeviceInterface::ConstBuffer firmware_image, + bool skip_verify); + +// TODO: remove this function, as it's only used by the remote +// interface. +// Tries to perform DFU on all USB devices of the same vendor and +// product ID. +util::Status UsbUpdateAllDfuDevices(UsbManager* usb_manager, uint16_t vendor_id, + uint16_t product_id, + const std::string& firmware_filename, + bool skip_verify); + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_DFU_UTIL_H_ diff --git a/driver/usb/usb_driver.cc b/driver/usb/usb_driver.cc new file mode 100644 index 0000000..645ee5f --- /dev/null +++ b/driver/usb/usb_driver.cc @@ -0,0 +1,1772 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
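+
+// Overview (summary of the implementation below): UsbDriver drives a single
+// DarwiNN device over USB. DoOpen() optionally runs DFU to load firmware,
+// claims the ML interface, initializes the chip, and starts a worker thread
+// (WorkerThreadFunc) that services a queue of UsbIoRequest objects produced
+// by the DMA scheduler, issuing bulk-out, bulk-in, and interrupt transfers
+// according to the configured OperatingMode.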
+ +#include "driver/usb/usb_driver.h" + +#include +#include +#include +#include +#include + +#include "api/buffer.h" +#include "api/watchdog.h" +#include "driver/device_buffer_mapper.h" +#include "driver/dma_info_extractor.h" +#include "driver/hardware_structures.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/memory/address_utilities.h" +#include "driver/memory/dram_allocator.h" +#include "driver/package_registry.h" +#include "driver/single_tpu_request.h" +#include "driver/time_stamper/time_stamper.h" +#include "driver/top_level_handler.h" +#include "driver/tpu_request.h" +#include "driver/usb/usb_dfu_util.h" +#include "driver/usb/usb_latest_firmware.h" +#include "driver/usb/usb_ml_commands.h" +#include "port/cleanup.h" +#include "port/errors.h" +#include "port/integral_types.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "port/time.h" +#include "port/tracing.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// Sleep time before we try or retry to open a device. +// TODO: revisit this setting after we finalize PHY tuning. +constexpr int kSleepTimeMicroSecondsBeforeRetry = 1000000; + +// TODO: revisit this setting after we finalize PHY tuning. +constexpr int kMaxNumOfRetryAfterReset = 25; + +constexpr uint16_t kTargetAppVendorId = 0x18D1; +constexpr uint16_t kTargetAppProductId = 0x9302; + +constexpr uint16_t kTargetDfuVendorId = 0x1A6E; +constexpr uint16_t kTargetDfuProductId = 0x089A; + +// This class implements BasicLockable concept, to be used with +// std::conditional_variable_any. +// The implementation is specialized as no re-locking is needed. +// Since the conditional variable is used in the end of a do-while loop, +// re-locking is just waste of time. +class Lock2 { + public: + Lock2(StdCondMutexLock& m1, StdCondMutexLock& m2) : m1_(m1), m2_(m2) {} + + ~Lock2() = default; + + // Does nothing. This function is part of BasicLockable concept. + void lock() { + // do nothing. + VLOG(10) << "lock (does nothing)"; + } + + // Unlocks both Lockables. This function is part of BasicLockable concept. + void unlock() { + VLOG(10) << "Unlocks both mutex"; + m1_.unlock(); + m2_.unlock(); + } + + private: + StdCondMutexLock& m1_; + StdCondMutexLock& m2_; +}; + +// Returns the number of entries in a queue concept, protected by the given +// mutex. +template +static typename Queue::size_type QueueSize(const Queue* queue, + std::mutex* mutex) { + StdCondMutexLock state_lock(mutex); + return queue->size(); +} + +// Returns if a queue concept is empty, protected by the given mutex. +template +static bool IsQueueEmpty(const Queue* queue, std::mutex* mutex) { + StdCondMutexLock state_lock(mutex); + return queue->empty(); +} + +// Returns the first entry in a queue concept, protected by the given mutex. 
+template +static typename Queue::value_type QueuePop(Queue* queue, std::mutex* mutex) { + StdCondMutexLock state_lock(mutex); + typename Queue::value_type item = queue->front(); + queue->pop(); + return item; +} + +} // namespace + +UsbDriver::UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, std::unique_ptr time_stamper) + : Driver( + [](config::ChipConfig* chip_config) { + CHECK(chip_config != nullptr); + return chip_config->GetChip(); + }(chip_config.get()), + std::move(executable_registry), driver_options, + std::move(time_stamper)), + chip_config_(std::move(chip_config)), + registers_(std::move(registers)), + allocator_(gtl::MakeUnique( + chip_config_->GetChipStructures().allocation_alignment_bytes)), + top_level_interrupt_manager_(std::move(top_level_interrupt_manager)), + fatal_error_interrupt_controller_( + std::move(fatal_error_interrupt_controller)), + top_level_handler_(std::move(top_level_handler)), + dram_allocator_(std::move(dram_allocator)), + options_(options), + dma_info_extractor_( + options.usb_enable_processing_of_hints + ? DmaInfoExtractor::ExtractorType::kDmaHints + : DmaInfoExtractor::ExtractorType::kFirstInstruction, + options.usb_enable_overlapping_requests), + dma_scheduler_(api::Watchdog::MakeWatchdog( + driver_options.watchdog_timeout_ns(), + [this](int64) { HandleWatchdogTimeout(); })), + apex_csr_offsets_(chip_config_->GetApexCsrOffsets()), + cb_bridge_csr_offsets_(chip_config_->GetCbBridgeCsrOffsets()), + hib_kernel_csr_offsets_(chip_config_->GetHibKernelCsrOffsets()), + scu_csr_offsets_(chip_config_->GetScuCsrOffsets()), + usb_csr_offsets_(chip_config_->GetUsbCsrOffsets()), + hib_user_csr_offsets_(chip_config_->GetHibUserCsrOffsets()) { + run_controller_ = + gtl::MakeUnique(*chip_config_, registers_.get()); + + if (options_.mode == OperatingMode::kMultipleEndpointsSoftwareQuery) { + options_.usb_max_num_async_transfers = 1; + VLOG(5) << StringPrintf( + "force setting usb_max_num_async_transfers to 1 for software " + "query mode"); + } +} + +UsbDriver::UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::unique_ptr usb_device, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, std::unique_ptr time_stamper) + : UsbDriver(driver_options, std::move(chip_config), std::move(registers), + std::move(top_level_interrupt_manager), + std::move(fatal_error_interrupt_controller), + std::move(top_level_handler), std::move(dram_allocator), + std::move(executable_registry), options, + std::move(time_stamper)) { + usb_device_ = std::move(usb_device); +} + +UsbDriver::UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::function>()> + device_factory, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, std::unique_ptr time_stamper) + : 
UsbDriver(driver_options, std::move(chip_config), std::move(registers), + std::move(top_level_interrupt_manager), + std::move(fatal_error_interrupt_controller), + std::move(top_level_handler), std::move(dram_allocator), + std::move(executable_registry), options, + std::move(time_stamper)) { + device_factory_ = std::move(device_factory); +} + +UsbDriver::~UsbDriver() { + CHECK_OK(UnregisterAll()); + if (Close(api::Driver::ClosingMode::kGraceful).ok()) { + LOG(WARNING) << "Driver destroyed when open. Forced Close()."; + } +} + +util::Status UsbDriver::ValidateState(State expected_state) const { + return ValidateStates({expected_state}); +} + +util::Status UsbDriver::ValidateStates( + const std::vector& expected_states) const { + for (auto& state : expected_states) { + if (state_ == state) { + return util::Status(); // OK + } + } + + return util::FailedPreconditionError( + StringPrintf("Unexpected state %d.", state_)); +} + +util::Status UsbDriver::SetState(State next_state) { + driver_state_changed_.notify_all(); + + if ((next_state == kClosing) || (next_state == kPaused)) { + // Cancel all transfers when we enter closing or paused state. + // + // Cancellation generates new callbacks with canceled status which need to + // be handled for each transfer that is still active. Pointers to task + // records and hence bulk in/out requests could already been invalidated. + usb_device_->TryCancelAllTransfers(); + } + + switch (state_) { + case kOpen: + if ((next_state == kOpen) || (next_state == kClosing)) { + // There is nothing special to do. + state_ = next_state; + return util::Status(); // OK + } else if (next_state == kPaused) { + VLOG(7) << StringPrintf("%s try enable clock gating", __func__); + RETURN_IF_ERROR(top_level_handler_->EnableSoftwareClockGate()); + + state_ = next_state; + return util::Status(); // OK + } + break; + + case kPaused: + if (next_state == kPaused) { + // We're already paused. Do nothing. + return util::Status(); // OK + } else if ((next_state == kOpen) || (next_state == kClosing)) { + // Disable clock gating so we can access the chip. + VLOG(7) << StringPrintf("%s try disable clock gating", __func__); + RETURN_IF_ERROR(top_level_handler_->DisableSoftwareClockGate()); + + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosing: + if (next_state == kClosed) { + state_ = next_state; + return util::Status(); // OK + } + break; + + case kClosed: + if (next_state == kOpen) { + state_ = next_state; + return util::Status(); // OK + } + break; + } + + // Illegal state transition. + return util::FailedPreconditionError(StringPrintf( + "Invalid state transition. current=%d, next=%d.", state_, next_state)); +} + +// TODO: review the sequence with hardware team and convert them to +// use named constants. 
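+// As an illustrative sketch only (these names are not part of the current
+// code), the magic register values written below could become, e.g.:
+//   constexpr uint64 kEnableAllDescriptors = 0xFF;
+//   constexpr uint64 kEnableScHostInterruptDescriptorsOnly = 0xF0;
+//   constexpr uint64 kOutfeedChunk256Bytes = 0x20;
+//   constexpr uint64 kOutfeedChunk1KB = 0x80;
+// (0x20 and 0x80 appear to encode 256 B and 1 KB in 8-byte units, matching
+// the log messages in InitializeChip() below.)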
+util::Status UsbDriver::InitializeChip() { + TRACE_SCOPE("UsbDriver::InitializeChip"); + + ASSIGN_OR_RETURN(auto omc_reg, registers_->Read32(apex_csr_offsets_.omc0_00)); + constexpr int EFUSE_PROGRAMMING_REVISION_SHIFT = 24; + constexpr int EFUSE_PROGRAMMING_REVISION_MASK = 0xFF; + const uint8_t efuse_programming_revision = + (omc_reg >> EFUSE_PROGRAMMING_REVISION_SHIFT) & + EFUSE_PROGRAMMING_REVISION_MASK; + VLOG(1) << StringPrintf("e-fuse programming revision: %d", + efuse_programming_revision); + + if (options_.usb_enable_bulk_descriptors_from_device) { + VLOG(7) << StringPrintf("%s Enabling all descriptors", __func__); + RETURN_IF_ERROR(registers_->Write(usb_csr_offsets_.descr_ep, 0xFF)); + } else { + VLOG(7) << StringPrintf("%s Enabling only sc host interrupt descriptors", + __func__); + RETURN_IF_ERROR(registers_->Write(usb_csr_offsets_.descr_ep, 0xF0)); + } + + switch (options_.mode) { + case OperatingMode::kMultipleEndpointsHardwareControl: + case OperatingMode::kMultipleEndpointsSoftwareQuery: + VLOG(7) << StringPrintf("%s Enabling multiple EP mode", __func__); + RETURN_IF_ERROR(registers_->Write(usb_csr_offsets_.multi_bo_ep, 1)); + break; + + case OperatingMode::kSingleEndpoint: + VLOG(7) << StringPrintf("%s Enabling single EP mode", __func__); + RETURN_IF_ERROR(registers_->Write(usb_csr_offsets_.multi_bo_ep, 0)); + break; + + default: + return util::FailedPreconditionError("Unrecognized USB operating mode"); + } + + if ((!options_.usb_force_largest_bulk_in_chunk_size) && + (usb_device_->GetDeviceSpeed() == + UsbStandardCommands::DeviceSpeed::kHigh)) { + // If we know it's USB2 Highspeed (max bulk packet size 512B), and there is + // no option to force max chunk size, use 256B chunk size to limit packet + // length to 256B. This is a workaround for b/73181174 + VLOG(7) << StringPrintf("%s Setting 256B chunk for USB 2 High Speed", + __func__); + + // This is an optimiaztion for some host controllers, so everyone knows the + // response would be only 256 bytes long. Without this, host would need to + // identify the response as a "short" packet and hence ends transfer. Not + // all host controllers need this change. + cap_bulk_in_size_at_256_bytes_ = true; + + RETURN_IF_ERROR( + registers_->Write(usb_csr_offsets_.outfeed_chunk_length, 0x20)); + } else { + // Otherwise, use the largest chunk size (1KB) for max bulk packet size + // determined by the hardware. + VLOG(7) << StringPrintf("%s Setting 1KB chunk for bulk-ins", __func__); + + cap_bulk_in_size_at_256_bytes_ = false; + RETURN_IF_ERROR( + registers_->Write(usb_csr_offsets_.outfeed_chunk_length, 0x80)); + } + + return util::Status(); // OK. +} + +util::Status UsbDriver::RegisterAndEnableAllInterrupts() { + // TODO: Register interrupts to interrupt EP. + RETURN_IF_ERROR(fatal_error_interrupt_controller_->EnableInterrupts()); + RETURN_IF_ERROR(top_level_interrupt_manager_->EnableInterrupts()); + + return util::Status(); // OK +} + +util::Status UsbDriver::DisableAllInterrupts() { + RETURN_IF_ERROR(top_level_interrupt_manager_->DisableInterrupts()); + RETURN_IF_ERROR(fatal_error_interrupt_controller_->DisableInterrupts()); + + return util::Status(); // OK +} + +void UsbDriver::HandleEvent(const util::Status& status, + const UsbMlCommands::EventDescriptor& event_info) { + if (status.ok()) { + // TODO: analyze if there is any failure case we can recover from. 
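+    // As currently written, any error from HandleDmaDescriptor() below is
+    // fatal: CHECK_OK aborts the process. A recovery path would have to
+    // classify the failure and, for example, reset the device instead.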
+ CHECK_OK(HandleDmaDescriptor( + event_info.tag, event_info.offset, event_info.length, + options_.usb_enable_bulk_descriptors_from_device)); + } else if (util::IsDeadlineExceeded(status)) { + VLOG(10) << StringPrintf("%s timed out, ignore.", __func__); + } else if (util::IsCancelled(status)) { + VLOG(10) << StringPrintf("%s cancelled, ignore.", __func__); + } else { + LOG(FATAL) << StringPrintf("%s failed. %s", __func__, + status.error_message().c_str()); + } +} + +util::Status UsbDriver::CheckHibError() { + // Indicates no HIB Fatal Error. + constexpr uint64 kHibErrorStatusNone = 0; + + ASSIGN_OR_RETURN(uint64 hib_error_status, + registers_->Read(hib_user_csr_offsets_.hib_error_status)); + if (hib_error_status == kHibErrorStatusNone) { + return util::Status(); // OK + } + + ASSIGN_OR_RETURN( + uint64 hib_first_error_status, + registers_->Read(hib_user_csr_offsets_.hib_first_error_status)); + + auto error_string = StringPrintf( + "HIB Error. hib_error_status = %016llx, hib_first_error_status = %016llx", + static_cast(hib_error_status), // NOLINT(runtime/int) + static_cast( // NOLINT(runtime/int) + hib_first_error_status)); + LOG(ERROR) << error_string; + return util::InternalError(error_string); +} + +void UsbDriver::HandleInterrupt( + const util::Status& status, + const UsbMlCommands::InterruptInfo& interrupt_info) { + if (status.ok()) { + VLOG(10) << StringPrintf("%s interrupt received.", __func__); + + constexpr uint32_t kFatalErrorInterruptMask = 1; + constexpr int kTopLevelInterruptBitShift = 1; + const uint32_t kTopLevelInterruptMask = + ((1 << top_level_interrupt_manager_->NumInterrupts()) - 1) + << kTopLevelInterruptBitShift; + if (interrupt_info.raw_data & kFatalErrorInterruptMask) { + VLOG(1) << StringPrintf("%s Fatal error interrupt received.", __func__); + CHECK_OK(CheckHibError()); + CHECK_OK(fatal_error_interrupt_controller_->ClearInterruptStatus(0)); + } + if ((interrupt_info.raw_data & kTopLevelInterruptMask) != 0) { + uint32_t top_level_interrupts = static_cast( + (interrupt_info.raw_data & kTopLevelInterruptMask) >> + kTopLevelInterruptBitShift); + + for (int id = 0; id < top_level_interrupt_manager_->NumInterrupts(); + ++id) { + const uint32_t mask = 1 << id; + if ((top_level_interrupts & mask) == mask) { + VLOG(1) << StringPrintf("%s Top level interrupt %d received.", + __func__, id); + CHECK_OK(top_level_interrupt_manager_->HandleInterrupt(id)); + } + } + } + + } else if (util::IsCancelled(status)) { + VLOG(10) << StringPrintf("%s cancelled, ignore.", __func__); + } else { + VLOG(1) << status.message(); + } +} + +uint32_t UsbDriver::GetCredits(UsbMlCommands::DescriptorTag tag) { + if (!registers_->Write32(apex_csr_offsets_.omc0_00, 0xffffffff).ok()) { + VLOG(1) << StringPrintf("%s write failed. silently assume 0 credit", + __func__); + return 0; + } + + auto query_result = registers_->Read(usb_csr_offsets_.ep_status_credit); + if (!query_result.status().ok()) { + VLOG(1) << StringPrintf("%s read failed. 
silently assume 0 credit", + __func__); + return 0; + } + const uint64_t gcb_credits = query_result.ValueOrDie(); + + constexpr uint32_t kCounterInBytes = 8; + constexpr uint32_t kCreditShift = 21; + constexpr uint32_t kCreditMask = (1ULL << kCreditShift) - 1; + + const uint32_t instructions = + static_cast((gcb_credits & kCreditMask) * kCounterInBytes); + const uint32_t input_activations = static_cast( + ((gcb_credits >> kCreditShift) & kCreditMask) * kCounterInBytes); + const uint32_t parameters = static_cast( + ((gcb_credits >> (kCreditShift * 2)) & kCreditMask) * kCounterInBytes); + + VLOG(10) << StringPrintf("%s credits: instructions %u, input %u, params %u", + __func__, instructions, input_activations, + parameters); + + switch (tag) { + case UsbMlCommands::DescriptorTag::kInstructions: + return instructions; + case UsbMlCommands::DescriptorTag::kInputActivations: + return input_activations; + case UsbMlCommands::DescriptorTag::kParameters: + return parameters; + default: + LOG(FATAL) << StringPrintf("%s unrecognized tag", __func__); + unreachable(); // NOLINT + } +} + +// TODO: breaks up this function according to functionality. +util::StatusOr UsbDriver::ProcessIo() { + TRACE_SCOPE("UsbDriver::ProcessIO"); + static constexpr int kNumBulkOutTags = 3; + static constexpr uint8_t tag_to_bulk_out_endpoint_id[kNumBulkOutTags] = { + UsbMlCommands::kInstructionsEndpoint, + UsbMlCommands::kInputActivationsEndpoint, + UsbMlCommands::kParametersEndpoint}; + int num_active_transfers = 0; + std::bitset tag_to_bulk_out_with_unsent_chunk; + + // Remove UsbIoRequest that are completed. + while (!io_requests_.empty()) { + const auto& io_request = io_requests_.front(); + if (!io_request.IsCompleted()) { + break; + } + + // If DMA descriptors are coming in, and hint is not yet matched. Consider + // it not completed. + if (options_.usb_enable_bulk_descriptors_from_device && + io_request.GetSourceAndMatchStatus() == + UsbIoRequest::SourceAndMatchStatus::kHintNotYetMatched) { + break; + } + + if (io_request.FromDmaHint()) { + CHECK_OK(dma_scheduler_.NotifyDmaCompletion(io_request.dma_info())); + } + + if (io_request.GetTag() == UsbMlCommands::DescriptorTag::kInterrupt0) { + TRACE_WITHIN_SCOPE("UsbDriver::ProcessIO::RequestCompletion"); + CHECK_OK(dma_scheduler_.NotifyRequestCompletion()); + HandleTpuRequestCompletion(); + } + VLOG(9) << "IO completed"; + io_requests_.pop_front(); + } + + // TODO: Remove this loop. + // As an intermediate step, IO requests are completely pulled out from a + // Request. Eventually, We should GetNextDma() only when we can perform DMA. + ASSIGN_OR_RETURN(auto* dma_info, dma_scheduler_.GetNextDma()); + while (dma_info) { + io_requests_.push_back(UsbIoRequest(dma_info)); + ASSIGN_OR_RETURN(dma_info, dma_scheduler_.GetNextDma()); + } + + // True if some libusb command has been issued and we should skip waiting on + // the completion queue. + bool is_task_state_changed = false; + + // All previous bulk out requests must be completed before a bulk in, and + // interrupt 0 request can be processed. + bool is_any_bulk_out_still_uncompleted = false; + bool is_any_bulk_in_still_uncompleted = false; + + for (auto& io_request : io_requests_) { + if (io_request.IsCompleted()) { + continue; + } + + if (io_request.GetTag() == UsbMlCommands::DescriptorTag::kInterrupt0) { + // Nothing to do for interrupts. 
+ continue; + } + + auto io_type = io_request.GetType(); + const int tag = static_cast(io_request.GetTag()); + if (io_type == UsbIoRequest::Type::kBulkOut) { + // block further processing of any bulk in requests. + is_any_bulk_out_still_uncompleted = true; + + if (io_request.IsActive()) { + // simply increase the counter and proceed to see if we can fire another + // request for the next chunk. + num_active_transfers += io_request.GetActiveCounts( + options_.max_bulk_out_transfer_size_in_bytes); + } else { + if (options_.mode == OperatingMode::kMultipleEndpointsHardwareControl) { + // In multiple-ep hardware control mode, let's continue + // searching for a different tag to be sent out. It's never okay to + // interleve/mix chunks from different requests of the same tag. + if (tag_to_bulk_out_with_unsent_chunk[static_cast( + UsbMlCommands::DescriptorTag::kInstructions)]) { + // If there is any uncompleted instructions, break from the search. + break; + } else if (tag_to_bulk_out_with_unsent_chunk.count() == + (kNumBulkOutTags - 1)) { + // If all endpoints(tags) supported, other than instructions, are + // busy, break from the search. If instructions endpoint is busy, we + // already break from previous clause. + break; + } else if (tag_to_bulk_out_with_unsent_chunk[tag]) { + // If something sharing with my endpoint is busy, keep looking for + // something different. + continue; + } + } else { + if (tag_to_bulk_out_with_unsent_chunk.any()) { + // In other modes, especially for single-ep mode, continue searching + // is not necessary if any previous request has unsent chunk data + // waiting. This is because we only send out header once for each + // request. Note that it's okay to start sending the next chunk if + // the request is already active, as they share the same header. + // Also, it's okay to start processing a new request after the + // previous request has all its data pushed into the pipeline. + break; + } + } + } + + if (is_any_bulk_in_still_uncompleted) { + // Prevent any queuing of bulk-out after bulk-in, in single ep mode. + // It's not very safe to allow bulk-out after bulk-in in single ep mode, + // as the bulk-out could hog the internal data path and prevent bulk-in + // from completion. In multiple ep mode, the internal data path cannot + // be occupied for long time, and hence it's safe to queue any request. + if (options_.mode == OperatingMode::kSingleEndpoint) { + // Due to hardware limitation, bulk-out could delay completion of + // bulk-in till deadlock occurs. + VLOG(10) << StringPrintf( + "[%d-%d] all bulk in requests must be completed before " + "processing of bulk out can start, wait", + io_request.id(), tag); + break; + } + } else if (num_active_transfers >= options_.usb_max_num_async_transfers) { + VLOG(10) << StringPrintf( + "[%d-%d] number of concurrent transfers too high, wait " + "(%d >= %d)", + io_request.id(), tag, num_active_transfers, + options_.usb_max_num_async_transfers); + break; + } + + if (!io_request.HasNextChunk()) { + // There is nothing we can do for this request. All data has been put + // in transit. + continue; + } + + if (options_.mode == OperatingMode::kMultipleEndpointsSoftwareQuery) { + // TODO: add some mechansim to slowly poll for available + // credits. + // Setting this to true would cause unpleasant busy looping. + is_task_state_changed = true; + + // query for credits available. + uint32_t credits = GetCredits(io_request.GetTag()); + // proceed only if the credits are above some threashold. 
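+        // "Credits" here are the byte counts unpacked by GetCredits() above:
+        // the ep_status_credit CSR packs three 21-bit counters (instructions,
+        // input activations, parameters), each in units of 8 bytes. For
+        // example, a raw counter of 0x100 for this tag means
+        // 0x100 * 8 = 2048 bytes may be sent without overrunning the device.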
+ if (credits <= options_.software_credits_lower_limit_in_bytes) { + VLOG(10) << StringPrintf( + "[%d-%d] available credits too low, wait (%u <= %u)", + io_request.id(), tag, credits, + options_.software_credits_lower_limit_in_bytes); + + // Stop further processing if credit for any endpoint is lower than + // the limit. + // TODO: allow a different endpoint to proceed. + break; + } + + // Clamp the transfer size with available credits. + uint32_t transfer_size = + std::min(options_.max_bulk_out_transfer_size_in_bytes, credits); + + const auto device_buffer = io_request.GetNextChunk(transfer_size); + auto host_buffer = address_space_.Translate(device_buffer).ValueOrDie(); + UsbMlCommands::ConstBuffer transfer_buffer(host_buffer.ptr(), + host_buffer.size_bytes()); + + ++num_active_transfers; + if (io_request.HasNextChunk()) { + // This request still has some data not sent over to the pipeline. + // Setting this to true prevents rquests of the same tag to start, as + // chunks from different requests of the same tag must not interleve. + tag_to_bulk_out_with_unsent_chunk[tag] = true; + } + + // To make sure the query result for credits is accurate, we have to + // use sync transfer. Because we only send data according to credits + // available, there is no way we could get a timeout error. + util::Status status = usb_device_->BulkOutTransfer( + tag_to_bulk_out_endpoint_id[tag], transfer_buffer, __func__); + if (status.ok()) { + io_request.NotifyTransferComplete(transfer_size); + VLOG(10) << StringPrintf("[%d-%d] bulk out for %u bytes done", + io_request.id(), tag, transfer_size); + } else { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step would + // be resetting the device. + LOG(FATAL) << StringPrintf( + "[%d-%d] bulk out for %u bytes failed. Abort. %s", + io_request.id(), tag, transfer_size, status.ToString().c_str()); + } + } else if (options_.mode == + OperatingMode::kMultipleEndpointsHardwareControl) { + is_task_state_changed = true; + + const auto device_buffer = io_request.GetNextChunk( + options_.max_bulk_out_transfer_size_in_bytes); + auto host_buffer = address_space_.Translate(device_buffer).ValueOrDie(); + UsbMlCommands::ConstBuffer transfer_buffer(host_buffer.ptr(), + host_buffer.size_bytes()); + uint32_t transfer_size = static_cast(transfer_buffer.size()); + + ++num_active_transfers; + if (io_request.HasNextChunk()) { + // This request still has some data not sent over to the pipeline. + // Setting this to true prevents rquests of the same tag to start, as + // chunks from different requests of the same tag must not interleve. + tag_to_bulk_out_with_unsent_chunk[tag] = true; + } + + util::Status async_request_status = usb_device_->AsyncBulkOutTransfer( + tag_to_bulk_out_endpoint_id[tag], transfer_buffer, + [this, &io_request, tag, transfer_size](util::Status status) { + // Inject a functor into a completion queue driven by the worker + // thread. Note that the reference to io_request could have been + // invalidated when the async transfer is cancelled. + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push([&io_request, tag, status, transfer_size] { + // Starting from here is an functor which would be executed + // within the worker thread context, after the async transfer + // has been completed. 
+ if (status.ok()) { + io_request.NotifyTransferComplete(transfer_size); + VLOG(10) << StringPrintf("[%d-%d] bulk out for %u bytes done", + io_request.id(), tag, transfer_size); + } else { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step + // would be resetting the device. + LOG(FATAL) << StringPrintf( + "[%d-%d] bulk out failed. Abort. %s", io_request.id(), + tag, status.ToString().c_str()); + } + }); + driver_state_changed_.notify_all(); + }, + __func__); + + if (!async_request_status.ok()) { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step would + // be resetting the device. + LOG(FATAL) << StringPrintf( + "[%d-%d] async transfer out for %u bytes failed. Abort. %s", + io_request.id(), tag, transfer_size, + async_request_status.ToString().c_str()); + } + } else if (options_.mode == OperatingMode::kSingleEndpoint) { + is_task_state_changed = true; + + if (!io_request.IsActive() && !io_request.IsCompleted() && + !io_request.IsHeaderSent()) { + // Prepare the header with full data size. + // add one extra count for the header transfer. + ++num_active_transfers; + + VLOG(10) << StringPrintf("%s [%d-%d] bulk out header", __func__, + io_request.id(), tag); + + io_request.SetHeader(usb_device_->PrepareHeader( + io_request.GetTag(), io_request.GetBuffer().size_bytes())); + + util::Status async_request_status = usb_device_->AsyncBulkOutTransfer( + UsbMlCommands::kSingleBulkOutEndpoint, + UsbMlCommands::ConstBuffer(io_request.header()), + [this, &io_request, tag](util::Status status) { + // Inject a functor into a completion queue driven by the worker + // thread. Note that the reference to io_request could have been + // invalidated when the async transfer is cancelled. + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push([&io_request, tag, status] { + // Starting from here is an functor which would be executed + // within the worker thread context, after the async transfer + // has been completed. + if (status.ok()) { + VLOG(10) << StringPrintf("[%d-%d] bulk out for header done", + io_request.id(), tag); + } else { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step + // would be resetting the device. + LOG(FATAL) << StringPrintf( + "[%d-%d] bulk out for header failed. Abort. %s", + io_request.id(), tag, status.ToString().c_str()); + } + }); + driver_state_changed_.notify_all(); + }, + __func__); + + if (!async_request_status.ok()) { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step would + // be resetting the device. + LOG(FATAL) << StringPrintf( + "[%d-%d] bulk out for header failed. Abort. %s", + io_request.id(), tag, async_request_status.ToString().c_str()); + } + } + + // Send the actual data in chunks. + const auto device_buffer = io_request.GetNextChunk( + options_.max_bulk_out_transfer_size_in_bytes); + auto host_buffer = address_space_.Translate(device_buffer).ValueOrDie(); + UsbMlCommands::ConstBuffer transfer_buffer(host_buffer.ptr(), + host_buffer.size_bytes()); + uint32_t transfer_size = static_cast(transfer_buffer.size()); + + ++num_active_transfers; + if (io_request.HasNextChunk()) { + // This request still has some data not sent over to the pipeline. + // Setting this to true prevents rquests of the same tag to start, as + // chunks from different requests of the same tag must not interleve. 
+ tag_to_bulk_out_with_unsent_chunk[tag] = true; + } + + util::Status async_request_status = usb_device_->AsyncBulkOutTransfer( + UsbMlCommands::kSingleBulkOutEndpoint, transfer_buffer, + [this, &io_request, tag, transfer_size](util::Status status) { + // Inject a functor into a completion queue driven by the worker + // thread. Note that the reference to io_request could have been + // invalidated when the async transfer is cancelled. + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push([&io_request, tag, status, transfer_size] { + // Starting from here is an functor which would be executed + // within the worker thread context, after the async transfer + // has been completed. + if (status.ok()) { + io_request.NotifyTransferComplete(transfer_size); + VLOG(10) << StringPrintf( + "%s [%d-%d] bulk out for %u bytes done", __func__, + io_request.id(), tag, transfer_size); + } else { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step + // would be resetting the device. + LOG(FATAL) + << StringPrintf("transfer on tag %d failed. Abort. %s", + tag, status.ToString().c_str()); + } + }); + driver_state_changed_.notify_all(); + }, + __func__); + + if (!async_request_status.ok()) { + // TODO: terminate the task early, as there is no + // chance we can continue. The more reasonable next step would + // be resetting the device. + LOG(FATAL) << StringPrintf( + "%s [%d-%d] async transfer out failed. Abort. %s", __func__, + io_request.id(), tag, async_request_status.ToString().c_str()); + } + } + } else if (io_type == UsbIoRequest::Type::kBulkIn) { + // If queuing is enabled, bulk-in requests are handled similar to + // interrupt and dma descriptors. + if (options_.usb_enable_queued_bulk_in_requests) { + // Skip if any previous bulk-in request is still incomplete. This is + // because all bulk-in requests have to be serialized. + if (is_any_bulk_in_still_uncompleted) { + continue; + } + + // Walk through filled buffer queue. + while (!filled_bulk_in_buffers_.empty()) { + // We're about to change the state of io requests. + // This flag indicates we need to call ProcessIo() again. + is_task_state_changed = true; + + // We're getting a reference, as we're directly modifying + // begin_offset. + FilledBulkInInfo& filled_info = filled_bulk_in_buffers_.front(); + + const Buffer& buffer = bulk_in_buffers_[filled_info.buffer_index]; + + const size_t available_data_size_bytes = + filled_info.end_offset - filled_info.begin_offset; + + auto device_buffer = io_request.GetNextChunk(); + + auto host_buffer = + address_space_.Translate(device_buffer).ValueOrDie(); + + const size_t requested_size_bytes = host_buffer.size_bytes(); + + const size_t transferred_bytes = + std::min(available_data_size_bytes, requested_size_bytes); + + memcpy(host_buffer.ptr(), buffer.ptr() + filled_info.begin_offset, + transferred_bytes); + + io_request.NotifyTransferComplete(transferred_bytes); + + if (available_data_size_bytes <= requested_size_bytes) { + VLOG(10) << StringPrintf( + "[%d-%d] bulk in for %zu bytes has yielded %zu bytes from " + "index [%d]", + io_request.id(), tag, requested_size_bytes, + available_data_size_bytes, filled_info.buffer_index); + + // We've depleted the buffer. Return it to available queue. + available_bulk_in_buffers_.push(filled_info.buffer_index); + filled_bulk_in_buffers_.pop(); + + if (io_request.IsCompleted()) { + // There is no need to check the next buffer, as we've just + // completed this io_request. 
+ break; + } + } else { + VLOG(10) << StringPrintf( + "[%d-%d] bulk in for %zu bytes has yielded %zu bytes " + "(OVERFLOW) from index [%d]", + io_request.id(), tag, requested_size_bytes, + available_data_size_bytes, filled_info.buffer_index); + + filled_info.begin_offset += requested_size_bytes; + + // We just completed this io_request, stop iterating through + // buffers. + break; + } + } + + if (!io_request.IsCompleted()) { + // This flag would prevent further bulk-in request in all modes, and + // further bulk-out in single-ep mode. + is_any_bulk_in_still_uncompleted = true; + } + // Continue to the next io_request. + continue; + } + + if (!options_.usb_enable_overlapping_bulk_in_and_out && + is_any_bulk_out_still_uncompleted) { + VLOG(10) << StringPrintf( + "[%d-%d] configured to start only after all " + "bulk-out requests complete, wait", + io_request.id(), tag); + break; + } else if (num_active_transfers >= options_.usb_max_num_async_transfers) { + VLOG(10) << StringPrintf( + "[%d-%d] number of concurrent transfers too high, wait " + "(%d >= %d)", + io_request.id(), tag, num_active_transfers, + options_.usb_max_num_async_transfers); + break; + } else if (io_request.IsActive()) { + ++num_active_transfers; + // Still transferring data in. Break from the loop. + VLOG(10) << StringPrintf( + "[%d-%d] this bulk in request is still active, wait", + io_request.id(), tag); + break; + } else { + is_task_state_changed = true; + is_any_bulk_in_still_uncompleted = true; + + auto device_buffer = (cap_bulk_in_size_at_256_bytes_) + ? io_request.GetNextChunk(256) + : io_request.GetNextChunk(); + + auto host_buffer = address_space_.Translate(device_buffer).ValueOrDie(); + UsbMlCommands::MutableBuffer transfer_buffer(host_buffer.ptr(), + host_buffer.size_bytes()); + uint32_t transfer_size = static_cast(transfer_buffer.size()); + + VLOG(10) << StringPrintf("[%d-%d] bulk in for %zu bytes", + io_request.id(), tag, transfer_buffer.size()); + + ++num_active_transfers; + + util::Status async_request_status = usb_device_->AsyncBulkInTransfer( + UsbMlCommands::kBulkInEndpoint, transfer_buffer, + [this, &io_request, tag, transfer_size]( + util::Status status, size_t num_bytes_transferred) { + // Inject a functor into a completion queue driven by the worker + // thread. Note that the reference to io_request could have been + // invalidated when the async transfer is cancelled. + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push([&io_request, status, num_bytes_transferred, + tag, transfer_size] { + // Starting from here is an functor which would be executed + // within the worker thread context, after the async + // transfer has been completed. + if (status.ok()) { + io_request.NotifyTransferComplete(num_bytes_transferred); + VLOG(10) << StringPrintf( + "[%d-%d] bulk in for %u bytes has yielded %zu bytes", + io_request.id(), tag, transfer_size, + num_bytes_transferred); + } else { + // Note that the reference to io_request could have been + // invalidated when the async transfer is cancelled. + // TODO: fail the task and allow reset of the + // chip. + LOG(FATAL) + << StringPrintf("%s transfer in failed. Abort. %s", + __func__, status.ToString().c_str()); + } + }); + driver_state_changed_.notify_all(); + }, + __func__); + + if (!async_request_status.ok()) { + LOG(FATAL) << StringPrintf("[%d-%d] transfer in failed. Abort", + io_request.id(), tag); + } + + // Break from further processing if there is any bulk-in request which + // has not been completed. 
+ break; + } + } else { + LOG(FATAL) << StringPrintf("%s [%d-%d] unexpected request type", __func__, + io_request.id(), tag); + } + } + + return is_task_state_changed; +} + +util::Status UsbDriver::HandleDmaDescriptor(UsbMlCommands::DescriptorTag tag, + uint64_t device_virtual_address, + uint32_t size_bytes, + bool bulk_events_enabled) { + DeviceBuffer buffer(device_virtual_address, size_bytes); + VLOG(10) << StringPrintf( + "Digesting descriptor from device tag[%d], data[0x%llx], size[%zu]", + static_cast(tag), + static_cast( // NOLINT(runtime/int) + buffer.device_address()), + buffer.size_bytes()); + + // First check whether if there is any matching hint. + for (auto& io_request : io_requests_) { + const auto hint_tag = io_request.GetTag(); + const auto hint_type = io_request.GetType(); + const auto hint_buffer = io_request.GetBuffer(); + const auto hint_status = io_request.GetSourceAndMatchStatus(); + + if (hint_status == UsbIoRequest::SourceAndMatchStatus::kSubmittedByDevice || + hint_status == + UsbIoRequest::SourceAndMatchStatus::kHintAlreadyMatched) { + continue; + } + + if (hint_tag == UsbMlCommands::DescriptorTag::kInstructions) { + // Device never sends DMA descriptor for instructions, consider them as + // always matched. + io_request.SetMatched(); + continue; + } + + if (!bulk_events_enabled && + hint_type != UsbIoRequest::Type::kScHostInterrupt) { + // Only in-band scalar core interrupts can be matched. + continue; + } + + if (tag != hint_tag) { + // If DMA descriptor from device does not match hint, then it is a new + // DMA. + break; + } + + if (hint_tag != UsbMlCommands::DescriptorTag::kInterrupt0 && + hint_buffer != buffer) { + continue; + } + + io_request.SetMatched(); + return util::Status(); // OK. + } + + // If there is no matching hint, then USB driver should process the + // descriptor. + switch (tag) { + case UsbMlCommands::DescriptorTag::kInputActivations: + case UsbMlCommands::DescriptorTag::kParameters: + VLOG(9) << "Received new bulk out command"; + io_requests_.push_back(UsbIoRequest( + io_requests_.back().id(), UsbIoRequest::Type::kBulkOut, tag, buffer)); + break; + + case UsbMlCommands::DescriptorTag::kOutputActivations: + VLOG(9) << "Received new bulk in command"; + io_requests_.push_back(UsbIoRequest( + io_requests_.back().id(), UsbIoRequest::Type::kBulkIn, tag, buffer)); + break; + + case UsbMlCommands::DescriptorTag::kInterrupt0: + case UsbMlCommands::DescriptorTag::kInterrupt1: + case UsbMlCommands::DescriptorTag::kInterrupt2: + case UsbMlCommands::DescriptorTag::kInterrupt3: + VLOG(9) << "Received new interrupt"; + io_requests_.push_back(UsbIoRequest(io_requests_.back().id(), tag)); + break; + + // Instruction descriptor is never sent from device. + case UsbMlCommands::DescriptorTag::kInstructions: + case UsbMlCommands::DescriptorTag::kUnknown: + LOG(FATAL) << StringPrintf("Unknown descriptor from device"); + } + + return util::Status(); // OK. +} + +void UsbDriver::HandleQueuedBulkIn(const util::Status& status, int buffer_index, + size_t num_bytes_transferred) { + if (status.ok()) { + // Enqueue the filled buffer with actual data size. + filled_bulk_in_buffers_.push( + FilledBulkInInfo{buffer_index, 0, num_bytes_transferred}); + + VLOG(1) << StringPrintf("bulk in %zu bytes from buffer index [%d]", + num_bytes_transferred, buffer_index); + } else { + // num_bytes_transferred is not valid. Just return the buffer to available + // queue. 
+ + available_bulk_in_buffers_.push(buffer_index); + + if (!IsCancelled(status) && !IsDeadlineExceeded(status)) { + // TODO: convert to driver error. + LOG(FATAL) << StringPrintf("%s transfer in failed. %s", __func__, + status.ToString().c_str()); + } + } +} + +void UsbDriver::WorkerThreadFunc() { + VLOG(7) << StringPrintf("%s starting worker thread", __func__); + TRACE_START_THREAD("UsbDriverWorkerThread"); + + // Types of background operations that need to be triggered in parallel to IO + // request handling. + enum BackgroudOperations { + kReadOutputActivations = 0, + kReadEvent, + kReadInterrupt, + kNumBackgroundOperations, + }; + + // Current background operations. + std::bitset background_ops; + + do { + // Lock the driver and check for state. + StdCondMutexLock state_lock(&mutex_); + + VLOG(10) << StringPrintf( + "%s dispatching %d callback events in worker thread", __func__, + static_cast(QueueSize(&callback_queue_, &callback_mutex_))); + + while (!IsQueueEmpty(&callback_queue_, &callback_mutex_)) { + // Note the queue is not locked when the callback executes. This is + // intentional, as it simplifies design of interrupt handlers, allowing + // sync CSR access from the handlers. + QueuePop(&callback_queue_, &callback_mutex_)(); + } + + bool reevaluation_needed = false; + + if (state_ == kClosing) { + // If all buffers are available, flag that we're not reading output + // activations at this moment. (So it's okay to close the driver.) + if (available_bulk_in_buffers_.size() == + options_.usb_bulk_in_queue_capacity) { + background_ops[kReadOutputActivations] = false; + + VLOG(10) << "All bulk-in buffers are available"; + } + + if (background_ops.any() || !dma_scheduler_.IsEmpty()) { + VLOG(7) << "Driver is closing. Wait for async operations to complete."; + } else { + // Terminate the worker thread. + VLOG(7) + << "Driver is closing, and all async operations have completed."; + break; + } + } else if (state_ == kPaused) { + VLOG(7) << "Driver is paused. Do not initiate further device operations."; + } else { + // Check if any of the async operations needs to be re-installed. 
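+      // Re-arm pattern: each background operation is re-posted here whenever
+      // its previous async transfer has completed. The USB completion
+      // callbacks only push a small functor onto callback_queue_; the actual
+      // handling (and the re-posting below) always happens on this worker
+      // thread.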
+ if (!background_ops[kReadEvent]) { + VLOG(7) << StringPrintf("%s Re-installing event reader", __func__); + reevaluation_needed = true; + background_ops[kReadEvent] = true; + util::Status status = usb_device_->AsyncReadEvent( + [this, &background_ops]( + util::Status status, + const UsbMlCommands::EventDescriptor& event_info) { + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push([this, &background_ops, status, event_info] { + // Note this wrapping confuses thread safety analyzer + HandleEvent(status, event_info); + background_ops[kReadEvent] = false; + }); + driver_state_changed_.notify_all(); + }); + if (!status.ok()) { + VLOG(1) << StringPrintf("%s AsyncReadEvent failed:", __func__) + << status; + break; + } + } + if (!background_ops[kReadInterrupt]) { + VLOG(7) << StringPrintf("%s Re-installing interrupt reader", __func__); + background_ops[kReadInterrupt] = true; + reevaluation_needed = true; + util::Status status = usb_device_->AsyncReadInterrupt( + [this, &background_ops]( + util::Status status, + const UsbMlCommands::InterruptInfo& interrupt_info) { + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push( + [this, &background_ops, status, interrupt_info] { + // Note this wrapping confuses thread safety analyzer + HandleInterrupt(status, interrupt_info); + background_ops[kReadInterrupt] = false; + }); + driver_state_changed_.notify_all(); + }); + if (!status.ok()) { + VLOG(1) << StringPrintf("%s AsyncReadInterrupt failed:", __func__) + << status; + break; + } + } + + if (options_.usb_enable_queued_bulk_in_requests) { + while (!available_bulk_in_buffers_.empty()) { + const int buffer_index = available_bulk_in_buffers_.front(); + available_bulk_in_buffers_.pop(); + + VLOG(7) << StringPrintf( + "%s Installing bulk-in reader. buffer index [%d]", __func__, + buffer_index); + + background_ops[kReadOutputActivations] = true; + reevaluation_needed = true; + + UsbMlCommands::MutableBuffer transfer_buffer( + bulk_in_buffers_[buffer_index].ptr(), + bulk_in_buffers_[buffer_index].size_bytes()); + + // Clear data to prevent data leakage from request to request. + memset(transfer_buffer.data(), 0, transfer_buffer.size()); + + util::Status async_request_status = usb_device_->AsyncBulkInTransfer( + UsbMlCommands::kBulkInEndpoint, transfer_buffer, + [this, buffer_index](util::Status status, + size_t num_bytes_transferred) { + // This functor is executed directly from underlying completion + // callback thread. We need to transfer it to be processed in + // WorkerThreadFunc by pushing a new functor to the callback + // queue. + StdMutexLock queue_lock(&callback_mutex_); + callback_queue_.push( + [this, status, buffer_index, num_bytes_transferred] { + // This function is executed from WorkerThreadFunc. + // Note this wrapping confuses thread safety analyzer. + HandleQueuedBulkIn(status, buffer_index, + num_bytes_transferred); + }); + // Notify the worker thread to work on the callback queue. + driver_state_changed_.notify_all(); + }, + __func__); + + if (!async_request_status.ok()) { + // TODO: convert to some driver error. + LOG(FATAL) << "Bulk-in failed. Abort"; + } + } + } + + reevaluation_needed = ProcessIo().ValueOrDie(); + + // TODO: Enter kPaused state when dma_scheduler_.IsEmpty(). Any + // new task should kick the driver back to kOpen state. Note this is in + // contradiction to the plan to remove state in USB driver. 
+ } + + if (reevaluation_needed) { + VLOG(10) << StringPrintf("%s re-evaluation is needed", __func__); + } else { + StdCondMutexLock queue_lock(&callback_mutex_); + + Lock2 unlock_both(state_lock, queue_lock); + + if (callback_queue_.empty()) { + VLOG(10) << StringPrintf("%s waiting on state change", __func__); + + // Release the lock and wait for further state change. + driver_state_changed_.wait(unlock_both); + + VLOG(10) << StringPrintf("%s driver state change detected", __func__); + } else { + VLOG(10) << StringPrintf("%s callback event available. skip waiting", + __func__); + } + } + } while (true); + + VLOG(7) << StringPrintf("%s leaving worker thread", __func__); +} + +util::StatusOr> +UsbDriver::CreateRawUsbDeviceWithRetry() { + TRACE_SCOPE("UsbDriver::CreateRawUsbDeviceWithRetry"); + util::Status result; + for (int i = 0; i < kMaxNumOfRetryAfterReset; ++i) { + TRACE_SCOPE("UsbDriver::CreateRawUsbDeviceWithRetry:try"); + + // Wait even for the first time, before opening the raw device. + // We found it seems to reduce chances of transfer errors in long-running + // back-to-back tests. + // TODO: revisit after the connection issue has been resolved. + { + TRACE_SCOPE("UsbDriver::CreateRawUsbDeviceWithRetry:Microsleep"); + Microsleep(kSleepTimeMicroSecondsBeforeRetry); + } + + // Try to open the raw device. + auto raw_device_or_error = device_factory_(); + + // Return early if we get an OK. + result = raw_device_or_error.status(); + if (result.ok()) { + return raw_device_or_error; + } + } + return result; +} + +util::Status UsbDriver::OpenMlUsbDevice() { + TRACE_SCOPE("UsbDriver::OpenMlUsbDevice"); + + VLOG(7) << "Opening device expecting application mode"; + + ASSIGN_OR_RETURN(auto raw_usb_device, CreateRawUsbDeviceWithRetry()); + + usb_device_ = gtl::MakeUnique(std::move(raw_usb_device), + options_.usb_timeout_millis); + + return usb_device_ ? util::Status() + : util::UnknownError("Failed to create ML device"); +} + +util::Status UsbDriver::PrepareUsbDevice() { + TRACE_SCOPE("UsbDriver::PrepareUsbDevice"); + + // 1) Send DFU Detach command if already in application mode + // 2) USB Reset + // 3) Perform DFU + // 4) USB Reset + std::unique_ptr raw_usb_device; + + VLOG(7) << "Open device and check if DFU is needed"; + + ASSIGN_OR_RETURN(raw_usb_device, CreateRawUsbDeviceWithRetry()); + + auto dfu_device = gtl::MakeUnique( + std::move(raw_usb_device), options_.usb_timeout_millis); + + ASSIGN_OR_RETURN(UsbDfuCommands::DeviceDescriptor device_desc, + dfu_device->GetDeviceDescriptor()); + + // Timeout before DFU Detach expires. + constexpr int kShortTimeoutMillis = 100; + bool expect_app_mode_after_reset = false; + + if ((device_desc.vendor_id == kTargetAppVendorId) && + (device_desc.product_id == kTargetAppProductId)) { + if (options_.usb_always_dfu) { + // Device is in app mode, send DFU Detach command. + VLOG(7) << "Device is in application mode, sending DFU Detach"; + + constexpr int kDfuInterface = 0; + RETURN_IF_ERROR(dfu_device->ClaimInterface(kDfuInterface)); + RETURN_IF_ERROR(dfu_device->DfuDetach(kShortTimeoutMillis)); + + expect_app_mode_after_reset = false; + } else { + // Device is in app mode, we're done. + VLOG(7) << "Device is already in application mode, skipping DFU"; + expect_app_mode_after_reset = true; + } + } else if ((device_desc.vendor_id == kTargetDfuVendorId) && + (device_desc.product_id == kTargetDfuProductId)) { + // Do nothing. 
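+    // (No DFU Detach is needed in this case; the device will simply be reset
+    // and re-flashed by the firmware download sequence below.)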
+ expect_app_mode_after_reset = false; + VLOG(7) << "Device is in DFU mode"; + } else { + return util::FailedPreconditionError("Unrecognized USB Vendor/Product ID"); + } + + VLOG(7) << "Resetting device"; + + // Close with USB Reset no matter which mode we're in. + RETURN_IF_ERROR( + dfu_device->Close(UsbDfuCommands::CloseAction::kGracefulPortReset)); + + if (expect_app_mode_after_reset) { + return OpenMlUsbDevice(); + } + + VLOG(7) << "Opening device expecting DFU mode"; + + // Try to open again. + ASSIGN_OR_RETURN(raw_usb_device, CreateRawUsbDeviceWithRetry()); + + dfu_device = gtl::MakeUnique(std::move(raw_usb_device), + options_.usb_timeout_millis); + + // Download firmware, and then upload for verification. + if (!options_.usb_firmware_image.empty()) { + VLOG(7) << "DFU with supplied firmware image"; + + // Use firmware image supplied. + RETURN_IF_ERROR(UsbUpdateDfuDevice(dfu_device.get(), + options_.usb_firmware_image, + /*skip_verify*/ false)); + } else { + // Use firmware image built-in. + VLOG(7) << "DFU with built-in firmware image"; + + const uint8* dfu_firmware = nullptr; + size_t dfu_firmware_size = 0; + switch (options_.mode) { + case OperatingMode::kMultipleEndpointsHardwareControl: + case OperatingMode::kMultipleEndpointsSoftwareQuery: + dfu_firmware = apex_latest_multi_ep; + dfu_firmware_size = apex_latest_multi_ep_len; + break; + + case OperatingMode::kSingleEndpoint: + dfu_firmware = apex_latest_single_ep; + dfu_firmware_size = apex_latest_single_ep_len; + break; + + default: + return util::FailedPreconditionError("Unrecognized operating mode"); + } + + RETURN_IF_ERROR(UsbUpdateDfuDevice( + dfu_device.get(), + UsbDeviceInterface::ConstBuffer( + reinterpret_cast(dfu_firmware), dfu_firmware_size), + /*skip_verify=*/ false)); + } + + VLOG(7) << "Resetting device"; + // Reset to trigger switching to application mode. + RETURN_IF_ERROR( + dfu_device->Close(UsbDfuCommands::CloseAction::kGracefulPortReset)); + + return OpenMlUsbDevice(); +} + +util::Status UsbDriver::DoOpen(bool debug_mode) { + TRACE_SCOPE("UsbDriver::DoOpen"); + + StdMutexLock state_lock(&mutex_); + RETURN_IF_ERROR(ValidateState(/*expected_state=*/kClosed)); + + if (options_.usb_enable_queued_bulk_in_requests) { + if (!options_.usb_enable_overlapping_bulk_in_and_out) { + return util::FailedPreconditionError( + "Overlapping bulk-in/out must be enabled for queued bulk-in " + "feature"); + } + + constexpr unsigned int k1kBMask = 1024 - 1; + if (options_.usb_bulk_in_max_chunk_size_in_bytes & k1kBMask) { + return util::OutOfRangeError( + "Bulk-in buffer max chunk size must be 1024-byte aligned"); + } + + if (options_.usb_bulk_in_queue_capacity <= 0) { + return util::OutOfRangeError("Bulk-in queue capacity must be positive"); + } + } else { + options_.usb_bulk_in_queue_capacity = 0; + } + + if (device_factory_) { + RETURN_IF_ERROR(PrepareUsbDevice()); + } else { + // No device factory is provided. An instance must already be supplied. 
+ if (usb_device_ == nullptr) { + return util::FailedPreconditionError( + "Either device factory or device instance must be supplied"); + } + } + + switch (usb_device_->GetDeviceSpeed()) { + case UsbStandardCommands::DeviceSpeed::kLow: + return util::FailedPreconditionError("USB Low speed is not supported"); + + case UsbStandardCommands::DeviceSpeed::kFull: + case UsbStandardCommands::DeviceSpeed::kHigh: + if (options_.usb_fail_if_slower_than_superspeed) { + return util::FailedPreconditionError( + "Connection speed is too slow, fail."); + } else if (options_.mode != OperatingMode::kSingleEndpoint) { + return util::FailedPreconditionError( + "Connection speed is incompatible with operating mode, fail"); + } + break; + + case UsbStandardCommands::DeviceSpeed::kSuper: + break; + + case UsbStandardCommands::DeviceSpeed::kUnknown: + default: + VLOG(7) << "Connection speed is unknown, ignore speed constraint"; + break; + } + + constexpr int kMlInterface = 0; + RETURN_IF_ERROR(usb_device_->ClaimInterface(kMlInterface)); + + RETURN_IF_ERROR(registers_->Open(usb_device_.get())); + + RETURN_IF_ERROR(top_level_handler_->Open()); + auto top_level_handler_closer = + MakeCleanup([this] { CHECK_OK(top_level_handler_->Close()); }); + + // Disable clock gate and reset GCB for clean state. + RETURN_IF_ERROR(top_level_handler_->DisableSoftwareClockGate()); + RETURN_IF_ERROR(top_level_handler_->DisableHardwareClockGate()); + RETURN_IF_ERROR(top_level_handler_->EnableReset()); + + // Quit from reset mode before accessing the chip. + RETURN_IF_ERROR(top_level_handler_->QuitReset()); + RETURN_IF_ERROR(top_level_handler_->EnableHardwareClockGate()); + + RETURN_IF_ERROR(InitializeChip()); + if (!debug_mode) { + // Move all subsystems to Run state. + RETURN_IF_ERROR(run_controller_->DoRunControl(RunControl::kMoveToRun)); + } + + RETURN_IF_ERROR(RegisterAndEnableAllInterrupts()); + + if (cap_bulk_in_size_at_256_bytes_) { + constexpr size_t k256Bytes = 256; + if (options_.usb_bulk_in_max_chunk_size_in_bytes > k256Bytes) { + options_.usb_bulk_in_max_chunk_size_in_bytes = k256Bytes; + + VLOG(7) << "Reducing bulk-in request size to 256 bytes for USB2"; + } + } + + for (int i = 0; i < options_.usb_bulk_in_queue_capacity; ++i) { + auto chunk = DoMakeBuffer(options_.usb_bulk_in_max_chunk_size_in_bytes); + if (!chunk.IsValid()) { + return util::ResourceExhaustedError( + "Bulk-in buffer chunk allocation failure"); + } + + // Save the Buffer object into a container, so it will be destroyed when + // driver destructs. + bulk_in_buffers_.push_back(chunk); + + // Save the index of available Buffer into the queue. + available_bulk_in_buffers_.push(i); + } + + // DMA scheduler. + RETURN_IF_ERROR(dma_scheduler_.Open()); + auto dma_scheduler_closer = MakeCleanup([this] { + CHECK_OK(dma_scheduler_.Close(api::Driver::ClosingMode::kGraceful)); + }); + + worker_thread_ = std::thread([this] { WorkerThreadFunc(); }); + + // On-Chip DRAM allocator. + RETURN_IF_ERROR(dram_allocator_->Open()); + + // All good. Move state to open. + RETURN_IF_ERROR(SetState(kOpen)); + + // Release cleanup functions. 
+ dma_scheduler_closer.release(); + top_level_handler_closer.release(); + + return util::Status(); // OK +} + +util::Status UsbDriver::DoClose(bool in_error, api::Driver::ClosingMode mode) { + TRACE_SCOPE("UsbDriver::DoClose"); + + if (mode != api::Driver::ClosingMode::kGraceful) { + LOG(WARNING) << "Only graceful closing mode is currently supported in USB " + "driver; forcing to graceful"; + mode = api::Driver::ClosingMode::kGraceful; + } + + { + StdMutexLock state_lock(&mutex_); + RETURN_IF_ERROR(ValidateStates({kOpen, kPaused})); + + // Note our intention to close. Clocking gating is disabled here. + RETURN_IF_ERROR(SetState(kClosing)); + } + + worker_thread_.join(); + + // All good. Shut down stuff. This is best effort. So if things starts + // failing, keep going and try cleaning up as much as we can. + + RETURN_IF_ERROR(dma_scheduler_.Close(mode)); + RETURN_IF_ERROR(DisableAllInterrupts()); + RETURN_IF_ERROR(UnmapAllParameters()); + RETURN_IF_ERROR(run_controller_->DoRunControl(RunControl::kMoveToHalt)); + RETURN_IF_ERROR(top_level_handler_->EnableReset()); + RETURN_IF_ERROR(registers_->Close()); + RETURN_IF_ERROR(dram_allocator_->Close()); + + // Deallocate all bulk-in buffers. This is not absolutely necessary, but it's + // better to have a clean slate for the next Open. + bulk_in_buffers_.clear(); + + // Flush available buffers queue, marking we have no any buffer available. + while (!available_bulk_in_buffers_.empty()) { + available_bulk_in_buffers_.pop(); + } + + // All buffers should have been released, as all lsusb request should have + // been canceled. + CHECK(filled_bulk_in_buffers_.empty()); + + // Release ownership to the USB device instance. + usb_device_.reset(); + + // Finalize. + { + StdMutexLock state_lock(&mutex_); + RETURN_IF_ERROR(SetState(kClosed)); + } + + return util::Status(); // OK +} + +util::Status UsbDriver::DoCancelAndWaitRequests(bool in_error) { + RETURN_IF_ERROR(dma_scheduler_.CancelPendingRequests()); + if (!in_error) { + RETURN_IF_ERROR(dma_scheduler_.WaitActiveRequests()); + } + return util::Status(); // OK +} + +Buffer UsbDriver::DoMakeBuffer(size_t size_bytes) const { + Buffer buffer = allocator_->MakeBuffer(size_bytes); + + if (buffer.IsValid()) { + // Clear data to prevent data leakage from request to request. + memset(buffer.ptr(), 0, buffer.size_bytes()); + } + return buffer; +} + +util::StatusOr UsbDriver::DoMapBuffer( + const Buffer& buffer, DmaDirection direction) { + if (buffer.IsValid()) { + ASSIGN_OR_RETURN(auto device_buffer, address_space_.MapMemory(buffer)); + // TODO : this is dangerous: the std::bind captures a raw pointer to + // the underlying object of the unique_ptr'd address space. + // This will break if executable registry outlives address space in the + // driver. + return MappedDeviceBuffer( + device_buffer, std::bind(&NopAddressSpace::UnmapMemory, &address_space_, + std::placeholders::_1)); + } + + return MappedDeviceBuffer(); +} + +util::StatusOr> UsbDriver::DoCreateRequest( + const std::shared_ptr parent_request, + const ExecutableReference* executable_ref, TpuRequest::RequestType type) { + StdMutexLock lock(&mutex_); + RETURN_IF_ERROR(ValidateStates({kOpen})); + + // TODO: find a way to mix models, switching on and off descriptors + // on the fly. + if (!options_.usb_enable_bulk_descriptors_from_device) { + // If we disable bulk in/out descriptors from device, the hint must be + // complete. 
+ if (!executable_ref->executable().dma_hints()->fully_deterministic()) { + return util::FailedPreconditionError(StringPrintf( + "Executable '%s' must have fully deterministic DMA " + "hints when DMA descriptors from device are disabled.", + executable_ref->executable().name()->c_str())); + } + } + + return {std::make_shared( + next_id_++, parent_request, executable_ref, allocator_.get(), + dram_allocator_.get(), + gtl::MakeUnique(&address_space_), + &dma_info_extractor_, + chip_config_->GetChipStructures().minimum_alignment_bytes, type)}; +} + +util::Status UsbDriver::DoSubmit(std::shared_ptr request) { + TRACE_SCOPE("UsbDriver::DoSubmit"); + StdMutexLock state_lock(&mutex_); + RETURN_IF_ERROR(ValidateStates({kOpen})); + + // Validate and prepare request. + RETURN_IF_ERROR(request->Validate()); + RETURN_IF_ERROR(request->Prepare()); + + RETURN_IF_ERROR(dma_scheduler_.Submit(std::move(request))); + + // Set the driver state to open and kick off processing. + RETURN_IF_ERROR(SetState(kOpen)); + + TRACE_WITHIN_SCOPE("UsbDriver::DoSubmit::Finished"); + return util::Status(); // OK +} + +util::Status UsbDriver::DoSetRealtimeMode(bool on) { + // TODO: Implementing real-time scheduler support for USB as + // well. + return util::FailedPreconditionError( + "This driver does not support real-time mode."); +} + +util::Status UsbDriver::DoSetExecutableTiming( + const ExecutableReference* executable, const api::Timing& timing) { + // TODO: Implementing real-time scheduler support for USB as + // well. + return util::FailedPreconditionError( + "This driver does not support real-time mode."); +} + +void UsbDriver::CheckFatalError(const util::Status& status) { + // TODO: Forward to the client application for handling. + CHECK_OK(status) << "Driver fatal error"; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_driver.h b/driver/usb/usb_driver.h new file mode 100644 index 0000000..bfcf0be --- /dev/null +++ b/driver/usb/usb_driver.h @@ -0,0 +1,463 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_DRIVER_USB_USB_DRIVER_H_ +#define DARWINN_DRIVER_USB_USB_DRIVER_H_ + +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include // NOLINT + +#include "api/buffer.h" +#include "driver/aligned_allocator.h" +#include "driver/config/apex_csr_offsets.h" +#include "driver/config/cb_bridge_csr_offsets.h" +#include "driver/config/chip_config.h" +#include "driver/config/scu_csr_offsets.h" +#include "driver/config/usb_csr_offsets.h" +#include "driver/device_buffer.h" +#include "driver/device_buffer_mapper.h" +#include "driver/dma_chunker.h" +#include "driver/dma_info.h" +#include "driver/dma_info_extractor.h" +#include "driver/driver.h" +#include "driver/interrupt/interrupt_controller_interface.h" +#include "driver/interrupt/top_level_interrupt_manager.h" +#include "driver/memory/dma_direction.h" +#include "driver/memory/dram_allocator.h" +#include "driver/memory/nop_address_space.h" +#include "driver/package_registry.h" +#include "driver/registers/registers.h" +#include "driver/request.h" +#include "driver/run_controller.h" +#include "driver/single_queue_dma_scheduler.h" +#include "driver/top_level_handler.h" +#include "driver/tpu_request.h" +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_dfu_commands.h" +#include "driver/usb/usb_io_request.h" +#include "driver/usb/usb_ml_commands.h" +#include "driver/usb/usb_registers.h" +#include "port/integral_types.h" +#include "port/statusor.h" +#include "port/stringprintf.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// DarwiNN USB driver. Thread safe. +class UsbDriver : public Driver { + public: + // Denotes how endpoints are operated. + enum class OperatingMode { + // Use independent endpoints to transfer instructions, input activations, + // and parameters with hardware flow control. + kMultipleEndpointsHardwareControl = 0, + + // Use independent endpoints to transfer instructions, input activations, + // and parameters with software flow control. + kMultipleEndpointsSoftwareQuery = 1, + + // Use single endpoint to transfer instructions, input activations, and + // parameters with simple hardware flow control. + kSingleEndpoint = 2, + }; + + // USB driver options. + struct UsbDriverOptions { + // USB EP operting mode. + OperatingMode mode{OperatingMode::kSingleEndpoint}; + + // If true, bulk-in data is transmitted in largest chunks possible. + // By default, driver uses 1KB chunk size for USB3 and 256B for USB2. + // This is part of workaround for b/73181174 + bool usb_force_largest_bulk_in_chunk_size{false}; + + /* + * There are only 3 modes of operation regarding + * usb_enable_bulk_out_descriptors_from_device and + * usb_enable_processing_of_hints: + * + * 1) both true, we follow the hints, and + * use descriptors sent from device as validation. This mode doesn't work + * if the device sends a lot of bulk-out or bulk-in descriptors out which + * could clog the descriptor/bulk-in pipeline. + * + * 2) disable descriptors but enable hints. We blindly follow the hints + * and send data to device as fast as we can. The mode is similar to the + * previous one, but could be slightly faster. + * + * 3) enable descriptors but disable the hints. we use descriptors from + * device and pretend there is no hint from code gen, except for first one + * (for instructions). This mode doesn't work with multiple instruction + * chunks, as device is not capable of generating descriptors for + * instructions. 
+ * + */ + + // If true, all bulk in/out descriptors are enabled to be sent from + // device. + bool usb_enable_bulk_descriptors_from_device{false}; + + // If true, all hints from code generator are processed and followed. + bool usb_enable_processing_of_hints{true}; + + // Max number of concurrent async bulk transfers. + int usb_max_num_async_transfers{kDefaultMaxNumAsyncTransfers}; + + // Maximum amount of data to be sent to device in a single bulk out + // transfer. + uint32 max_bulk_out_transfer_size_in_bytes{ + kDefaultMaxBulkOutTransferSizeInBytes}; + + // Lower limit of credits in software flow control mode. + uint32 software_credits_lower_limit_in_bytes{ + kDefaultSoftwareCreditsLowerLimitInBytes}; + + // If true, the next queued request would be sent to device right after the + // current request gets into final extraction state. This feature is + // temporarily fixed at true and cannot be turned off. + bool usb_enable_overlapping_requests{true}; + + // If true, the fence between bulk-out and bulk-in would be lifted, allowing + // bulk-in to be issued before all bulk-out are finished. This feature could + // improve performance significantly on Android platform. + bool usb_enable_overlapping_bulk_in_and_out{true}; + + // If true, multiple bulk-in requests would be issued instead of just one at + // any moment. usb_enable_overlapping_bulk_in_and_out must also be true for + // this feature to be enabled. + bool usb_enable_queued_bulk_in_requests{true}; + + // If true, driver would fail to open if the current connection is low, + // full, or high speed. If the connection speed is not observable from + // underlying provider, this option is ignored. + bool usb_fail_if_slower_than_superspeed{false}; + + // General timeout for USB operations in milliseconds. + int usb_timeout_millis{6000}; + + // If non-empty, the firmware image to use for automatic DFU. + // This feature is only available when a device factory has been supplied. + std::vector usb_firmware_image; + + // If true, driver would always perform DFU at open. + // This feature is only available when a device factory has been supplied. + bool usb_always_dfu{true}; + + // Must be packet-size-aligned to avoid buffer overflow during bulk-in, + // which is 512-byte for USB2 HighSpeed and 1024-byte for USB3 SuperSpeed. + // TODO: Due to b/77531949, we can only set it to exactly 1024 for + // USB3 and 256 for USB2 for now. + size_t usb_bulk_in_max_chunk_size_in_bytes{1024}; + + // Max number of buffers to queue. + int usb_bulk_in_queue_capacity{32}; + }; + + // Constructs a device from the factory provided, and performs DFU according + // to options and discovered device state. Note that DFU requires closing and + // creating new device instances, and hence can only be achieved by supplying + // a factory. + UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::function>()> + device_factory, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, + std::unique_ptr time_stamper); + + // Constructs a driver instance around a supplied device object. + // Note that since no factory is provided, this driver cannot be re-opened + // after close. 
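+  // For illustration only: a typical client-side options setup for the
+  // single-endpoint, always-DFU configuration (every field shown exists in
+  // UsbDriverOptions above; the values are examples, not recommendations):
+  //
+  //   UsbDriver::UsbDriverOptions options;
+  //   options.mode = UsbDriver::OperatingMode::kSingleEndpoint;
+  //   options.usb_always_dfu = true;
+  //   options.usb_timeout_millis = 6000;
+  //   options.usb_bulk_in_queue_capacity = 32;
+  //
+  // The constructor declared next is the variant that wraps an already
+  // created device instance.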
+ UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::unique_ptr usb_device, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, + std::unique_ptr time_stamper); + + // This class is neither copyable nor movable. + UsbDriver(const UsbDriver&) = delete; + UsbDriver& operator=(const UsbDriver&) = delete; + + ~UsbDriver() override; + + uint64_t allocation_alignment_bytes() const override { + return chip_config_->GetChipStructures().allocation_alignment_bytes; + } + + protected: + util::Status DoOpen(bool debug_mode) LOCKS_EXCLUDED(mutex_) final; + util::Status DoClose(bool in_error, api::Driver::ClosingMode mode) + LOCKS_EXCLUDED(mutex_) final; + util::Status DoCancelAndWaitRequests(bool in_error) + LOCKS_EXCLUDED(mutex_) final; + + Buffer DoMakeBuffer(size_t size_bytes) const final; + util::StatusOr DoMapBuffer(const Buffer& buffer, + DmaDirection direction) final; + util::StatusOr> DoCreateRequest( + const std::shared_ptr parent_request, + const ExecutableReference* executable_ref, TpuRequest::RequestType type) + LOCKS_EXCLUDED(mutex_) final; + + util::Status DoSetExecutableTiming(const ExecutableReference* executable, + const api::Timing& timing) final; + util::Status DoSetRealtimeMode(bool on) final; + + util::Status DoSubmit(std::shared_ptr request_in) + LOCKS_EXCLUDED(mutex_) final; + + int64 MaxRemainingCycles() const override { + return dma_scheduler_.MaxRemainingCycles(); + } + + util::StatusOr> GetOldestActiveRequest() + const override { + return dma_scheduler_.GetOldestActiveRequest(); + } + + private: + // TODO: Eliminate state management here. Since this is now done + // in the base class. + // Driver state. Transitions : + // kClosed -> kOpen -> kClosing -> kClosed. + enum State { + kOpen, // Driver is Open. + kPaused, // Device has been paused. + kClosing, // Driver is Closing. + kClosed, // Driver is Closed. (Initial state.) + }; + + // Record information for a filled bulk-in buffer. + struct FilledBulkInInfo { + int buffer_index; + size_t begin_offset; + size_t end_offset; + }; + + // Default value for #usb_max_num_async_transfers, if not set by the client. + static constexpr int kDefaultMaxNumAsyncTransfers = 3; + + // Default value for #max_bulk_out_transfer_size_in_bytes, if not set by the + // client. + static constexpr uint32 kDefaultMaxBulkOutTransferSizeInBytes = 1024 * 1024; + + // Default value for #software_credits_lower_limit_in_bytes, if not set by the + // client. + static constexpr uint32 kDefaultSoftwareCreditsLowerLimitInBytes = 8 * 1024; + + // Constructor to be used as delegate target. + UsbDriver( + const api::DriverOptions& driver_options, + std::unique_ptr chip_config, + std::unique_ptr registers, + std::unique_ptr top_level_interrupt_manager, + std::unique_ptr + fatal_error_interrupt_controller, + std::unique_ptr top_level_handler, + std::unique_ptr dram_allocator, + std::unique_ptr executable_registry, + const UsbDriverOptions& options, + std::unique_ptr time_stamper); + + // Prepares USB device with resets and DFU according to options_. + util::Status PrepareUsbDevice(); + + // Creates a UsbMlCommands and assigns it to usb_device_, with timed retry. + util::Status OpenMlUsbDevice(); + + // Creates a raw USB device from device_factory_, with timed retry. 
+ util::StatusOr> + CreateRawUsbDeviceWithRetry(); + + // Attempts a state transition to the given state. + util::Status SetState(State next_state) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Validates that we are in the expected state. + util::Status ValidateState(State expected_state) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Validates that we are in any of the expected states. + util::Status ValidateStates(const std::vector& expected_states) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Catches all fatal error handling during runtime. + void CheckFatalError(const util::Status& status); + + // Initializes the chip through CSR access. + util::Status InitializeChip() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Registers and enables all interrupts that come through interrupt + // endpoint. + util::Status RegisterAndEnableAllInterrupts(); + + // Disables all interrupts that come through interrupt endpoint. + util::Status DisableAllInterrupts(); + + // Runs the worker thread. + void WorkerThreadFunc() LOCKS_EXCLUDED(mutex_); + + // Handles bulk-in completion event from device. + void HandleQueuedBulkIn(const util::Status& status, int buffer_index, + size_t num_bytes_transferred); + + // Handles data in/out and software interrupt events sent from the device, + // in the worker thread. Thread safety analysis is confused by wrapping + // this function into a functor, and hence has to be disabled. + void HandleEvent(const util::Status& status, + const UsbMlCommands::EventDescriptor& event_info) + NO_THREAD_SAFETY_ANALYSIS; + + // Handles hardware interrupt events sent from the device, in the + // worker thread. Thread safety analysis is confused by wrapping this + // function into a functor, and hence has to be disabled. + void HandleInterrupt(const util::Status& status, + const UsbMlCommands::InterruptInfo& interrupt_info) + NO_THREAD_SAFETY_ANALYSIS; + + // Retrieves credits for the endpoint specified with tag. + uint32_t GetCredits(UsbMlCommands::DescriptorTag tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Processes data in/out requests associated with specified task. + util::StatusOr ProcessIo() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Records DMA descriptors and in-band interrupts sent from device. + util::Status HandleDmaDescriptor(UsbMlCommands::DescriptorTag tag, + uint64_t device_virtual_address, + uint32_t size_bytes, + bool bulk_events_enabled); + + util::Status CheckHibError(); + + std::function>()> + device_factory_; + + // The current active USB device supporting ML commands. + std::unique_ptr usb_device_; + + // CSR offsets. + std::unique_ptr chip_config_; + + // Implements CSR access. + std::unique_ptr registers_; + + // Buffer management. + // TODO: Allocate zero copy USB buffers. + std::unique_ptr allocator_; + + // Protects access to callback queue, resource shared by the worker thread + // and callbacks from usb device. + mutable std::mutex callback_mutex_; + + // Stores functors submitted by callbacks, to be executed in work thread. + std::queue> callback_queue_ GUARDED_BY(callback_mutex_); + + // Maintains integrity of the driver state. + mutable std::mutex mutex_; + + // Driver state. + State state_ GUARDED_BY(mutex_){kClosed}; + + // Conditional variable for worker thread to wait on events from both + // application layer and callbacks from usb devices. + std::condition_variable_any driver_state_changed_; + + // Worker thread object. + std::thread worker_thread_; + + // ID for tracking requests. 
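+  // How the members above cooperate (a sketch inferred from this header; the
+  // actual loop lives in WorkerThreadFunc in usb_driver.cc): USB completion
+  // callbacks do no real work themselves, they only enqueue a functor and
+  // wake the worker thread, which is the only place driver state is mutated.
+  //
+  //   {
+  //     StdMutexLock lock(&callback_mutex_);
+  //     callback_queue_.push(
+  //         [this, status, event_info] { HandleEvent(status, event_info); });
+  //   }
+  //   driver_state_changed_.notify_all();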
+ int next_id_ GUARDED_BY(mutex_){0}; + + // Top level interrupt controller. + std::unique_ptr top_level_interrupt_manager_; + + // Fatal error interrupt controller. + std::unique_ptr + fatal_error_interrupt_controller_; + + // Implements device run control. + std::unique_ptr run_controller_; + + // Implements reset handler. + std::unique_ptr top_level_handler_; + + // Enables driver to allocate buffer in On-Chip DRAM (if any). + std::unique_ptr dram_allocator_; + + // Address space management. + NopAddressSpace address_space_; + + // Driver options. + UsbDriverOptions options_; + + // DMA info extractor. + DmaInfoExtractor dma_info_extractor_; + + // DMA scheduler. + SingleQueueDmaScheduler dma_scheduler_; + + // list is used instead of vector, for iterators/pointers must be kept stable + // through out async callbacks. + std::list io_requests_; + + // If true, limit every bulk-in request to be at most 256-byte long. + // This is part of workaround for b/73181174 + bool cap_bulk_in_size_at_256_bytes_{false}; + + // Container for all bulk-in buffers. + std::vector bulk_in_buffers_; + + // Container for indices for bulk-in buffers that are not queued for data in. + // Note the reason for using queue here is for easier log interpretation. + std::queue available_bulk_in_buffers_; + + // Container for indices for bulk-in buffers that contain data from device. + std::queue filled_bulk_in_buffers_; + + // CSR offsets. + const config::ApexCsrOffsets& apex_csr_offsets_; + const config::CbBridgeCsrOffsets& cb_bridge_csr_offsets_; + const config::HibKernelCsrOffsets& hib_kernel_csr_offsets_; + const config::ScuCsrOffsets& scu_csr_offsets_; + const config::UsbCsrOffsets& usb_csr_offsets_; + const config::HibUserCsrOffsets& hib_user_csr_offsets_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_DRIVER_H_ diff --git a/driver/usb/usb_io_request.cc b/driver/usb/usb_io_request.cc new file mode 100644 index 0000000..5892633 --- /dev/null +++ b/driver/usb/usb_io_request.cc @@ -0,0 +1,101 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_io_request.h" + +#include "port/logging.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" +#include "port/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace driver { +namespace { + +// Returns IO type for given DMA. 
+UsbIoRequest::Type ConvertToIoType(const DmaInfo& info) { + switch (info.type()) { + case DmaDescriptorType::kInstruction: + case DmaDescriptorType::kInputActivation: + case DmaDescriptorType::kParameter: + return UsbIoRequest::Type::kBulkOut; + + case DmaDescriptorType::kOutputActivation: + return UsbIoRequest::Type::kBulkIn; + + case DmaDescriptorType::kScalarCoreInterrupt0: + case DmaDescriptorType::kScalarCoreInterrupt1: + case DmaDescriptorType::kScalarCoreInterrupt2: + case DmaDescriptorType::kScalarCoreInterrupt3: + return UsbIoRequest::Type::kScHostInterrupt; + + default: + LOG(FATAL) << "Cannot be converted"; + unreachable(); // NOLINT + } +} + +} // namespace + +UsbIoRequest::UsbIoRequest(int id, UsbMlCommands::DescriptorTag tag) + : id_(id), + source_and_match_status_( + UsbIoRequest::SourceAndMatchStatus::kSubmittedByDevice), + type_(Type::kScHostInterrupt), + tag_(tag), + chunker_(DmaChunker::HardwareProcessing::kCommitted, DeviceBuffer()) {} + +UsbIoRequest::UsbIoRequest(int id, UsbIoRequest::Type type, + UsbMlCommands::DescriptorTag tag, + const DeviceBuffer& buffer) + : id_(id), + source_and_match_status_( + UsbIoRequest::SourceAndMatchStatus::kSubmittedByDevice), + type_(type), + tag_(tag), + // Bytes transferred for BulkIn is determined by device, and may not match + // the descriptor. In such case, BulkIn will be delivered in chunks. + chunker_((type == Type::kBulkIn) + ? DmaChunker::HardwareProcessing::kBestEffort + : DmaChunker::HardwareProcessing::kCommitted, + buffer) {} + +UsbIoRequest::UsbIoRequest(DmaInfo* dma_info) + : id_([dma_info]() { + CHECK(dma_info != nullptr); + return dma_info->id(); + }()), + source_and_match_status_(SourceAndMatchStatus::kHintNotYetMatched), + type_(ConvertToIoType(*dma_info)), + tag_(static_cast(dma_info->type())), + // Bytes transferred for BulkIn is determined by device, and may not match + // the descriptor. In such case, BulkIn will be delivered in chunks. + chunker_((type_ == Type::kBulkIn) + ? DmaChunker::HardwareProcessing::kBestEffort + : DmaChunker::HardwareProcessing::kCommitted, + dma_info->buffer()), + dma_info_(dma_info) {} + +void UsbIoRequest::SetMatched() { + CHECK(dma_info_ != nullptr); + VLOG(9) << StringPrintf("DMA[%d] hint matched with descriptor", + dma_info_->id()); + source_and_match_status_ = SourceAndMatchStatus::kHintAlreadyMatched; +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_io_request.h b/driver/usb/usb_io_request.h new file mode 100644 index 0000000..878eef8 --- /dev/null +++ b/driver/usb/usb_io_request.h @@ -0,0 +1,156 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
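+// Usage sketch (illustrative only, based on the accessors declared in this
+// header; the real consumer is the USB driver): a request built from a DMA
+// hint is drained chunk by chunk until its chunker reports completion.
+//
+//   UsbIoRequest request(&dma_info);  // constructed from a DMA hint
+//   while (request.HasNextChunk()) {
+//     DeviceBuffer chunk = request.GetNextChunk(max_transfer_bytes);
+//     // ... issue the corresponding bulk transfer for `chunk` ...
+//     request.NotifyTransferComplete(chunk.size_bytes());
+//   }
+//   // request.IsCompleted() now holds for bulk requests; interrupt entries
+//   // additionally need SetMatched() or a device-submitted descriptor.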
+ +#ifndef DARWINN_DRIVER_USB_USB_IO_REQUEST_H_ +#define DARWINN_DRIVER_USB_USB_IO_REQUEST_H_ + +#include + +#include "driver/device_buffer.h" +#include "driver/dma_chunker.h" +#include "driver/dma_info.h" +#include "driver/usb/usb_ml_commands.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Entry for either a hint, generated by code generator at compile time, or a +// request, submitted by device at run time. +class UsbIoRequest { + public: + // Basic type of the hint/request. + enum class Type { + // This is a bulk out request. Data is sent from host to device. + kBulkOut = 0, + + // This is a bulk in request. Data is sent from device to host. + kBulkIn, + + // This is an interrupt event. Signal is sent from device to host. + kScHostInterrupt, + }; + + enum class SourceAndMatchStatus { + // This is a hint, and we haven't seen a matching request from device. + kHintNotYetMatched = 0, + + // This is a hint, and we have received a matching request from device. + kHintAlreadyMatched, + + // This is a request submitted by device. + kSubmittedByDevice, + }; + + // Constructor for device descriptors. + UsbIoRequest(int id, UsbMlCommands::DescriptorTag tag); + UsbIoRequest(int id, Type type, UsbMlCommands::DescriptorTag tag, + const DeviceBuffer& buffer); + + // Constructor for DMA hints. + UsbIoRequest(DmaInfo* dma_info); + + // Accessors. + Type GetType() const { return type_; } + UsbMlCommands::DescriptorTag GetTag() const { return tag_; } + SourceAndMatchStatus GetSourceAndMatchStatus() const { + return source_and_match_status_; + } + const DeviceBuffer& GetBuffer() const { return chunker_.buffer(); } + + // Returns true if in given state. + bool IsHeaderSent() const { return !header_.empty(); } + bool IsActive() const { return chunker_.IsActive(); } + bool IsCompleted() const { + if (!chunker_.IsCompleted()) { + return false; + } + // Interrupt will be always come through descriptor path. + // It is completed when it was matched with hint, or submitted by device. + if (type_ == Type::kScHostInterrupt) { + return source_and_match_status_ != + SourceAndMatchStatus::kHintNotYetMatched; + } + return true; + } + + // Returns true if there is chunk to transfer. + bool HasNextChunk() const { return chunker_.HasNextChunk(); } + + // Returns not yet transferred chunk. + DeviceBuffer GetNextChunk() { return chunker_.GetNextChunk(); } + + // Returns a next chunk upto "num_bytes". + DeviceBuffer GetNextChunk(int num_bytes) { + return chunker_.GetNextChunk(num_bytes); + } + + // Notifies that "num_bytes" of transfer is completed. + void NotifyTransferComplete(int num_bytes) { + chunker_.NotifyTransfer(num_bytes); + } + + // Returns number of active chunks assuming each chunk is "bytes". + int GetActiveCounts(int bytes) const { + return chunker_.GetActiveCounts(bytes); + } + + // Marks this hint as it has been matched with a request/event sent from + // device. + void SetMatched(); + + // Returns header. + const std::vector& header() const { return header_; } + + // Sets header. + void SetHeader(const std::vector& header) { header_ = header; } + void SetHeader(std::vector&& header) { header_ = std::move(header); } + + // Returns true if created from DMA hint. + bool FromDmaHint() const { return dma_info_ != nullptr; } + DmaInfo* dma_info() const { return dma_info_; } + + // Returns id. + int id() const { return id_; } + + private: + // ID for debugging purpose. + const int id_; + + // Is this a hint or a request. 
If it's a hint, has it been matched with + // soemthing sent from device? + SourceAndMatchStatus source_and_match_status_; + + // Basic type of the hint/request. + const Type type_; + + // Tag is more detailed information under Type. For example, a bulk out type + // could be instruction, parameters, or input activations. + const UsbMlCommands::DescriptorTag tag_; + + // DMA chunker. + DmaChunker chunker_; + + // Contains valid pointer to DMA info if this is for hint. + DmaInfo* dma_info_{nullptr}; + + // Stores header used in single endpoint mode. + std::vector header_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_IO_REQUEST_H_ diff --git a/driver/usb/usb_ml_commands.cc b/driver/usb/usb_ml_commands.cc new file mode 100644 index 0000000..bb4aaf5 --- /dev/null +++ b/driver/usb/usb_ml_commands.cc @@ -0,0 +1,307 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_ml_commands.h" + +#include +#include + +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_standard_commands.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +namespace { + +constexpr size_t kRegister64RawDataSizeInBytes = 8; +constexpr size_t kRegister32RawDataSizeInBytes = 4; +constexpr size_t kInterruptRawDataSizeInBytes = 4; +constexpr size_t kEventRawDataSizeInBytes = 16; +constexpr size_t kPacketHeaderRawDataSizeInBytes = 8; + +} // namespace + +UsbMlCommands::UsbMlCommands(std::unique_ptr device, + TimeoutMillis default_timeout_msec) + : UsbStandardCommands(std::move(device), default_timeout_msec) { + VLOG(10) << __func__; +} + +UsbMlCommands::~UsbMlCommands() { VLOG(10) << __func__; } + +util::Status UsbMlCommands::DfuDetach(int interface_number, + uint16_t timeout_msec) { + VLOG(10) << StringPrintf("%s interface %d, timeout %u msec", __func__, + interface_number, timeout_msec); + + SetupPacket command{ + // Request type (00100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kClass, + CommandRecipient::kInterface), + + // Request id for DFU detach command + 0, + + // Timeout in milliseconds. + timeout_msec, + + // Interface number. + static_cast(interface_number), + + // Data length. + 0}; + + RETURN_IF_ERROR(SendControlCommand(command, __func__)); + + return Close(CloseAction::kGracefulPortReset); +} + +util::StatusOr UsbMlCommands::ReadRegister32( + uint32_t offset) { + VLOG(10) << StringPrintf("%s offset 0x%x", __func__, offset); + + Register32 result; + SetupPacket command{ + // Request type (0xC0). + ComposeUsbRequestType(CommandDataDir::kDeviceToHost, CommandType::kVendor, + CommandRecipient::kDevice), + + // Request id for read CSR 32-bit command. + 1, + + // Low 16-bit of the offset. 
+ static_cast(offset & 0xffff), + + // High 16-bit of the offset. + static_cast(offset >> 16), + + // Data length. + kRegister32RawDataSizeInBytes}; + + size_t num_bytes_transferred = 0; + RETURN_IF_ERROR(SendControlCommandWithDataIn( + command, + MutableBuffer(reinterpret_cast(&result), sizeof(result)), + &num_bytes_transferred, __func__)); + + if (num_bytes_transferred != sizeof(result)) { + return util::UnknownError("Invalid register data"); + } + + VLOG(7) << StringPrintf("%s [0x%X] == 0x%" PRIX32, __func__, offset, result); + + return result; +} + +util::StatusOr UsbMlCommands::ReadRegister64( + uint32_t offset) { + VLOG(10) << StringPrintf("%s offset 0x%x", __func__, offset); + + Register64 result; + SetupPacket command{ + // Request type (0xC0). + ComposeUsbRequestType(CommandDataDir::kDeviceToHost, CommandType::kVendor, + CommandRecipient::kDevice), + + // Request id for read CSR 64-bit command. + 0, + + // Low 16-bit of the offset. + static_cast(offset & 0xffff), + + // High 16-bit of the offset. + static_cast(offset >> 16), + + // Data length. + kRegister64RawDataSizeInBytes}; + + size_t num_bytes_transferred = 0; + RETURN_IF_ERROR(SendControlCommandWithDataIn( + command, + MutableBuffer(reinterpret_cast(&result), sizeof(result)), + &num_bytes_transferred, __func__)); + + if (num_bytes_transferred != sizeof(result)) { + return util::UnknownError("Invalid register data"); + } + + VLOG(7) << StringPrintf("%s [0x%X] == 0x%" PRIX64, __func__, offset, result); + + return result; +} + +util::Status UsbMlCommands::WriteRegister32(uint32_t offset, Register32 value) { + VLOG(7) << StringPrintf("%s [0x%X] := 0x%" PRIX32, __func__, offset, value); + + SetupPacket command{ + // Request type (10100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kVendor, + CommandRecipient::kDevice), + + // Request id for write CSR 32-bit command. + 1, + + // Low 16-bit of the offset. + static_cast(offset & 0xffff), + + // High 16-bit of the offset. + static_cast(offset >> 16), + + // Data length. + sizeof(value)}; + + return SendControlCommandWithDataOut( + command, ConstBuffer(reinterpret_cast(&value), sizeof(value)), + __func__); +} + +util::Status UsbMlCommands::WriteRegister64(uint32_t offset, Register64 value) { + VLOG(7) << StringPrintf("%s [0x%X] := 0x%" PRIX64, __func__, offset, value); + + SetupPacket command{ + // Request type (10100001b). + ComposeUsbRequestType(CommandDataDir::kHostToDevice, CommandType::kVendor, + CommandRecipient::kDevice), + + // Request id for write CSR 64-bit command. + 0, + + // Low 16-bit of the offset. + static_cast(offset & 0xffff), + + // High 16-bit of the offset. + static_cast(offset >> 16), + + // Data length. + sizeof(value)}; + + return SendControlCommandWithDataOut( + command, ConstBuffer(reinterpret_cast(&value), sizeof(value)), + __func__); +} + +std::vector UsbMlCommands::PrepareHeader(DescriptorTag tag, + uint32_t length) { + // Write 8-byte-long tag. + constexpr size_t kLengthSizeInBytes = 4; + + // length must be 4-byte long, otherwise we could be sending wrong data. 
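+  // Worked example of the 8-byte header assembled below: bytes 0..3 carry the
+  // length (host byte order; supported hosts are little-endian, which the
+  // memcpy relies on), byte 4 carries the tag in its low nibble, and the
+  // remaining bytes stay zero. For tag kParameters (2) and length 0x00001234:
+  //
+  //   34 12 00 00 - 02 00 00 00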
+ CHECK_EQ(sizeof(length), kLengthSizeInBytes); + + std::vector header_packet(kPacketHeaderRawDataSizeInBytes); + memcpy(header_packet.data(), &length, kLengthSizeInBytes); + *(header_packet.data() + kLengthSizeInBytes) = + (static_cast(tag) & 0xF); + + VLOG(10) << StringPrintf( + "%s ep %d: header hex %2x %2x %2x %2x - %2x %2x %2x %2x", __func__, + kSingleBulkOutEndpoint, header_packet[0], header_packet[1], + header_packet[2], header_packet[3], header_packet[4], header_packet[5], + header_packet[6], header_packet[7]); + + return header_packet; +} + +util::Status UsbMlCommands::WriteHeader(DescriptorTag tag, uint32_t length) { + std::vector header_packet = PrepareHeader(tag, length); + return BulkOutTransfer(kSingleBulkOutEndpoint, ConstBuffer(header_packet), + __func__); +} + +util::Status UsbMlCommands::AsyncReadEvent(const EventInDone& callback) { + auto event_data = + std::make_shared>(kEventRawDataSizeInBytes); + CHECK(event_data); + return AsyncBulkInTransfer( + kEventInEndpoint, MutableBuffer(event_data->data(), event_data->size()), + [event_data, callback](util::Status status, + size_t num_bytes_transferred) { + EventDescriptor event_descriptor; + if (!status.ok()) { + callback(status, event_descriptor); + return; + } + if (num_bytes_transferred != kEventRawDataSizeInBytes) { + VLOG(1) << StringPrintf("%s data lost. calling with empty event", + __func__); + callback(util::DataLossError(__func__), event_descriptor); + return; + } + memcpy(&event_descriptor.offset, event_data->data(), + sizeof(event_descriptor.offset)); + constexpr size_t kAddressSizeInBytes = 8; + constexpr size_t kLengthSizeInBytes = 4; + memcpy(&event_descriptor.length, + event_data->data() + kAddressSizeInBytes, kLengthSizeInBytes); + event_descriptor.tag = static_cast( + *(event_data->data() + kAddressSizeInBytes + kLengthSizeInBytes) & + 0xF); + + VLOG(7) << StringPrintf( + "%s tag:%d, offset:0x%" PRIX64 ", length %u", __func__, + static_cast(event_descriptor.tag), event_descriptor.offset, + event_descriptor.length); + + // OK. + callback(status, event_descriptor); + + VLOG(7) << StringPrintf("%s callback done", __func__); + }, + __func__); +} + +util::Status UsbMlCommands::AsyncReadInterrupt( + const InterruptInDone& callback) { + auto interrupt_data = + std::make_shared>(kInterruptRawDataSizeInBytes); + CHECK(interrupt_data); + return AsyncInterruptInTransfer( + kInterruptInEndpoint, + MutableBuffer(interrupt_data->data(), interrupt_data->size()), + [interrupt_data, callback](util::Status status, + size_t num_bytes_transferred) { + InterruptInfo interrupt_info = {0}; + if (!status.ok()) { + callback(status, interrupt_info); + return; + } + if (num_bytes_transferred != kInterruptRawDataSizeInBytes) { + callback(util::DataLossError(__func__), interrupt_info); + return; + } + memcpy(&interrupt_info.raw_data, interrupt_data->data(), + sizeof(interrupt_info.raw_data)); + VLOG(7) << StringPrintf("%s raw data 0x%X", __func__, + interrupt_info.raw_data); + + // OK. + callback(status, interrupt_info); + + VLOG(7) << StringPrintf("%s callback done", __func__); + }, + __func__); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_ml_commands.h b/driver/usb/usb_ml_commands.h new file mode 100644 index 0000000..4ec79a9 --- /dev/null +++ b/driver/usb/usb_ml_commands.h @@ -0,0 +1,132 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_USB_ML_COMMANDS_H_ +#define DARWINN_DRIVER_USB_USB_ML_COMMANDS_H_ + +#include "driver/usb/usb_device_interface.h" +#include "driver/usb/usb_standard_commands.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +// Thread-safe implementation of Machine Learning commands through USB +// interface. +class UsbMlCommands : public UsbStandardCommands { + public: + enum class DescriptorTag { + kUnknown = -1, + kInstructions = 0, + kInputActivations = 1, + kParameters = 2, + kOutputActivations = 3, + kInterrupt0 = 4, + kInterrupt1 = 5, + kInterrupt2 = 6, + kInterrupt3 = 7, + }; + + struct EventDescriptor { + DescriptorTag tag{DescriptorTag::kUnknown}; + uint32_t length; + uint64_t offset; + }; + + // Contains information retrieved from USB interrupt. + // TODO: further parse this raw_data and provide more readable + // information. + struct InterruptInfo { + uint32_t raw_data; + }; + + using InterruptInDone = + std::function; + using EventInDone = std::function; + + using Register32 = uint32_t; + using Register64 = uint64_t; + + // Bulk out endpoint id used in single bulk out endpoint mode. + static constexpr uint8_t kSingleBulkOutEndpoint = 1; + + // Bulk out endpoint id used for instruction stream in multiple bulk out mode. + static constexpr uint8_t kInstructionsEndpoint = 1; + + // Bulk out endpoint id used for input activation stream in multiple bulk out + // mode. + static constexpr uint8_t kInputActivationsEndpoint = 2; + + // Bulk out endpoint id used for parameter stream in multiple bulk out mode. + static constexpr uint8_t kParametersEndpoint = 3; + + // Bulk in endpoint id used for output activation stream. + static constexpr uint8_t kBulkInEndpoint = 1; + + // Bulk in endpoint id used for event stream. + static constexpr uint8_t kEventInEndpoint = 2; + + // Interrupt in endpoint id used for interrupt stream. + static constexpr uint8_t kInterruptInEndpoint = 3; + + // Constructs a new object from pointer to an USB device. + UsbMlCommands(std::unique_ptr device, + TimeoutMillis default_timeout_msec); + + // This class is neither copyable nor movable. + UsbMlCommands(const UsbMlCommands&) = delete; + UsbMlCommands& operator=(const UsbMlCommands&) = delete; + + ~UsbMlCommands() override; + + // Detaches from application mode, then closes the device with graceful port + // reset. The object becomes closed upon successful return. + // Note the interface number and timeout for detachment have to come from + // device configuration or discovered through parsing the interface + // descriptors. + util::Status DfuDetach(int interface_number, uint16_t timeout_msec); + + // Reads 32-bit CSR from device. + util::StatusOr ReadRegister32(uint32_t offset); + + // Reads 64-bit CSR from device. + util::StatusOr ReadRegister64(uint32_t offset); + + // Writes 32-bit CSR to device. + util::Status WriteRegister32(uint32_t offset, Register32 value); + + // Writes 64-bit CSR to device. + util::Status WriteRegister64(uint32_t offset, Register64 value); + + // Writes header to device, through the single bulk out endpoint. 
This + // function is only meaningful if the device is in single bulk out endpoint + // mode. + util::Status WriteHeader(DescriptorTag tag, uint32_t length); + + // Prepares header to be sent to device. This function + // is only meaningful if the device is in single bulk out endpoint mode. + std::vector PrepareHeader(DescriptorTag tag, uint32_t length); + + // Asynchrounously reads event from device. + util::Status AsyncReadEvent(const EventInDone& callback); + + // Asynchrounously read interrupt from device. + util::Status AsyncReadInterrupt(const InterruptInDone& callback); +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_ML_COMMANDS_H_ diff --git a/driver/usb/usb_registers.cc b/driver/usb/usb_registers.cc new file mode 100644 index 0000000..5e6be77 --- /dev/null +++ b/driver/usb/usb_registers.cc @@ -0,0 +1,76 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_registers.h" + +#include "driver/usb/usb_ml_commands.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +util::Status UsbRegisters::Open() { + return util::UnimplementedError("USB register open without attached device"); +} + +util::Status UsbRegisters::Open(UsbMlCommands* usb_device) { + usb_device_ = usb_device; + return util::Status(); // OK +} + +util::Status UsbRegisters::Close() { + usb_device_ = nullptr; + return util::Status(); // OK +} + +util::Status UsbRegisters::Write(uint64 offset, uint64 value) { + if (usb_device_) { + return usb_device_->WriteRegister64(static_cast(offset), value); + } + return util::FailedPreconditionError( + "USB register write without attached device"); +} + +util::StatusOr UsbRegisters::Read(uint64 offset) { + if (usb_device_) { + return usb_device_->ReadRegister64(static_cast(offset)); + } + return util::FailedPreconditionError( + "USB register read without attached device"); +} + +util::Status UsbRegisters::Write32(uint64 offset, uint32 value) { + if (usb_device_) { + return usb_device_->WriteRegister32(static_cast(offset), value); + } + return util::FailedPreconditionError( + "USB register write32 without attached device"); +} + +util::StatusOr UsbRegisters::Read32(uint64 offset) { + if (usb_device_) { + return usb_device_->ReadRegister32(static_cast(offset)); + } + return util::FailedPreconditionError( + "USB register read32 without attached device"); +} + +} // namespace driver +} // namespace darwinn +} // namespace platforms diff --git a/driver/usb/usb_registers.h b/driver/usb/usb_registers.h new file mode 100644 index 0000000..b150b85 --- /dev/null +++ b/driver/usb/usb_registers.h @@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_DRIVER_USB_USB_REGISTERS_H_ +#define DARWINN_DRIVER_USB_USB_REGISTERS_H_ + +#include + +#include "driver/registers/registers.h" +#include "driver/usb/usb_ml_commands.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +class UsbRegisters : public Registers { + public: + UsbRegisters() = default; + // This version without argument should never be used. Use the version with + // pointer to USB deviec instead. + util::Status Open() override; + // Enable the USB register object the actually communicate with the underlying + // device. + util::Status Open(UsbMlCommands* usb_device); + util::Status Close() override; + + // Accesses 64-bit registers. + util::Status Write(uint64 offset, uint64 value) override; + util::StatusOr Read(uint64 offset) override; + + // Accesses 32-bit registers. + util::Status Write32(uint64 offset, uint32 value) override; + util::StatusOr Read32(uint64 offset) override; + + private: + // Underlying device. + UsbMlCommands* usb_device_{nullptr}; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_REGISTERS_H_ diff --git a/driver/usb/usb_standard_commands.cc b/driver/usb/usb_standard_commands.cc new file mode 100644 index 0000000..e371b7c --- /dev/null +++ b/driver/usb/usb_standard_commands.cc @@ -0,0 +1,134 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "driver/usb/usb_standard_commands.h" + +#include + +#include "driver/usb/usb_device_interface.h" +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace driver { + +UsbStandardCommands::UsbStandardCommands( + std::unique_ptr device, + TimeoutMillis default_timeout_msec) + : device_(std::move(device)), default_timeout_msec_(default_timeout_msec) { + VLOG(10) << __func__; +} + +UsbStandardCommands::~UsbStandardCommands() { VLOG(10) << __func__; } + +util::StatusOr +UsbStandardCommands::GetDeviceDescriptor() { + VLOG(10) << __func__; + // The raw size of standard USB device descriptor. 
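+  // Byte layout parsed below (this is the standard USB device descriptor, so
+  // the offsets mirror the USB specification rather than anything
+  // DarwiNN-specific):
+  //
+  //   [0]  bLength            [1]  bDescriptorType    [2..3]   bcdUSB
+  //   [4]  bDeviceClass       [5]  bDeviceSubClass    [6]      bDeviceProtocol
+  //   [7]  bMaxPacketSize0    [8..9] idVendor         [10..11] idProduct
+  //   [12..13] bcdDevice      [14] iManufacturer      [15]     iProduct
+  //   [16] iSerialNumber      [17] bNumConfigurations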
+  constexpr size_t kDeviceDescriptorRawByteSize = 18;
+  uint8 descriptor_buffer[kDeviceDescriptorRawByteSize];
+  size_t num_bytes_transferred = 0;
+
+  RETURN_IF_ERROR(
+      GetDescriptor(UsbDeviceInterface::DescriptorType::kDevice, 0,
+                    gtl::MutableArraySlice<uint8>(descriptor_buffer,
+                                                  sizeof(descriptor_buffer)),
+                    &num_bytes_transferred, __func__));
+
+  if (num_bytes_transferred < kDeviceDescriptorRawByteSize) {
+    return util::UnknownError("Device descriptor is too short");
+  }
+
+  DeviceDescriptor descriptor = {0};
+
+  // Fill fields from raw bytes into a more accessible data structure.
+  // memcpy is used for all multi-byte data fields to avoid any potential
+  // alignment issues.
+  memcpy(&descriptor.usb_version_bcd, descriptor_buffer + 2, 2);
+  descriptor.device_class =
+      static_cast<UsbDeviceInterface::DeviceClass>(descriptor_buffer[4]);
+  descriptor.device_subclass = descriptor_buffer[5];
+  descriptor.bDeviceProtocol = descriptor_buffer[6];
+  descriptor.max_packet_size_0 = descriptor_buffer[7];
+  memcpy(&descriptor.vendor_id, descriptor_buffer + 8, 2);
+  memcpy(&descriptor.product_id, descriptor_buffer + 10, 2);
+  memcpy(&descriptor.device_version_bcd, descriptor_buffer + 12, 2);
+  descriptor.manufacturer_name_index = descriptor_buffer[14];
+  descriptor.product_name_index = descriptor_buffer[15];
+  descriptor.serial_number_index = descriptor_buffer[16];
+  descriptor.num_configurations = descriptor_buffer[17];
+
+  VLOG(7) << StringPrintf("Vendor ID: 0x%x", descriptor.vendor_id);
+  VLOG(7) << StringPrintf("Product ID: 0x%x", descriptor.product_id);
+
+  return descriptor;
+}
+
+util::StatusOr<UsbStandardCommands::ConfigurationDescriptor>
+UsbStandardCommands::GetConfigurationDescriptor(uint8_t index,
+                                                size_t max_extra_data_length) {
+  VLOG(10) << StringPrintf("%s index %d", __func__, index);
+  // The raw size of a standard USB configuration descriptor.
+  constexpr size_t kConfigDescriptorRawByteSize = 9;
+  const size_t total_data_length =
+      kConfigDescriptorRawByteSize + max_extra_data_length;
+  size_t num_bytes_transferred = 0;
+  ConfigurationDescriptor descriptor;
+  descriptor.raw_data.resize(total_data_length);
+
+  RETURN_IF_ERROR(
+      GetDescriptor(UsbDeviceInterface::DescriptorType::kConfig, 0,
+                    gtl::MutableArraySlice<uint8>(descriptor.raw_data.data(),
+                                                  descriptor.raw_data.size()),
+                    &num_bytes_transferred, __func__));
+
+  if (num_bytes_transferred < kConfigDescriptorRawByteSize) {
+    return util::UnknownError("Configuration descriptor is too short");
+  }
+
+  descriptor.raw_data.resize(num_bytes_transferred);
+
+  descriptor.num_interfaces = descriptor.raw_data[4];
+  descriptor.configuration_value = descriptor.raw_data[5];
+  descriptor.configuration_name_index = descriptor.raw_data[6];
+  const uint8_t attributes = descriptor.raw_data[7];
+  descriptor.is_self_powered = (attributes >> 6) & 1;
+  descriptor.supports_remote_wakeup = (attributes >> 5) & 1;
+  descriptor.encoded_max_power = descriptor.raw_data[8];
+
+  VLOG(7) << StringPrintf("Configuration requested: %d", index);
+  VLOG(7) << StringPrintf("Configuration reported: %d",
+                          descriptor.configuration_value);
+  VLOG(7) << StringPrintf("Number of interfaces: %u",
+                          descriptor.num_interfaces);
+  VLOG(7) << StringPrintf("Is self powered: %d", descriptor.is_self_powered);
+  VLOG(7) << StringPrintf("Supports remote wakeup: %d",
+                          descriptor.supports_remote_wakeup);
+  VLOG(7) << StringPrintf("Encoded max power: 0x%x",
+                          descriptor.encoded_max_power);
+  VLOG(7) << StringPrintf("Raw data size: %d",
+                          static_cast<int>(descriptor.raw_data.size()));
+
+  return descriptor;
+}
+
+}  // namespace driver
+}  // namespace darwinn
+}  // namespace platforms
diff --git a/driver/usb/usb_standard_commands.h b/driver/usb/usb_standard_commands.h
new file mode 100644
index 0000000..620a340
--- /dev/null
+++ b/driver/usb/usb_standard_commands.h
@@ -0,0 +1,276 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_DRIVER_USB_USB_STANDARD_COMMANDS_H_
+#define DARWINN_DRIVER_USB_USB_STANDARD_COMMANDS_H_
+
+#include "driver/usb/usb_device_interface.h"
+
+namespace platforms {
+namespace darwinn {
+namespace driver {
+
+// TODO: provide a mechanism, such as locked/unlocked versions of functions and
+// a busy state, to prevent any interruption in the middle of a long sequence
+// like a firmware update.
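+// Usage sketch for the configuration-descriptor support declared below
+// (illustrative only; the descriptor framing comes from the USB spec, not
+// from this header, and `commands` stands in for an already-open
+// UsbStandardCommands instance): every descriptor in raw_data starts with
+// [bLength, bDescriptorType], so callers walk it to find interfaces and
+// endpoints.
+//
+//   ASSIGN_OR_RETURN(auto config, commands.GetConfigurationDescriptor(
+//                                     0, /*max_extra_data_length=*/256));
+//   for (size_t offset = 0; offset + 2 <= config.raw_data.size();) {
+//     const uint8_t length = config.raw_data[offset];
+//     const uint8_t type = config.raw_data[offset + 1];
+//     if (length == 0) break;        // malformed descriptor, stop walking
+//     if (type == 4) { /* interface descriptor: see InterfaceDescriptor */ }
+//     offset += length;
+//   }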
+class UsbStandardCommands {
+ public:
+  using ConstBuffer = UsbDeviceInterface::ConstBuffer;
+  using MutableBuffer = UsbDeviceInterface::MutableBuffer;
+  using DataInDone = UsbDeviceInterface::DataInDone;
+  using DataOutDone = UsbDeviceInterface::DataOutDone;
+  using CloseAction = UsbDeviceInterface::CloseAction;
+  using TimeoutMillis = UsbDeviceInterface::TimeoutMillis;
+  using SetupPacket = UsbDeviceInterface::SetupPacket;
+  using CommandDataDir = UsbDeviceInterface::CommandDataDir;
+  using CommandType = UsbDeviceInterface::CommandType;
+  using CommandRecipient = UsbDeviceInterface::CommandRecipient;
+  using DeviceSpeed = UsbDeviceInterface::DeviceSpeed;
+
+  // Device descriptor, as retrieved from the device by GetDeviceDescriptor.
+  // The detailed meaning of each field can be found in the USB spec.
+  // Data in this structure determines how the host identifies this device.
+  struct DeviceDescriptor {
+    // USB spec release number in BCD.
+    uint16 usb_version_bcd;
+
+    // Class of this device.
+    UsbDeviceInterface::DeviceClass device_class;
+
+    // Subclass of this device.
+    uint8 device_subclass;
+
+    // Protocol this device speaks.
+    uint8 bDeviceProtocol;
+
+    // Packet size for endpoint 0.
+    uint8 max_packet_size_0;
+
+    // Vendor ID.
+    uint16 vendor_id;
+
+    // Product ID.
+    uint16 product_id;
+
+    // Device release number in BCD.
+    uint16 device_version_bcd;
+
+    // Name of the manufacturer, as an index into string descriptors.
+    uint8_t manufacturer_name_index;
+
+    // Name of the product, as an index into string descriptors.
+    uint8_t product_name_index;
+
+    // Serial number string, as an index into string descriptors.
+    uint8_t serial_number_index;
+
+    // Number of supported configurations.
+    uint8 num_configurations;
+  };
+
+  // Configuration descriptor, as retrieved from the device by
+  // GetConfigurationDescriptor.
+  // The detailed meaning of each field can be found in the USB spec.
+  struct ConfigurationDescriptor {
+    // Number of interfaces supported by this configuration.
+    uint8_t num_interfaces;
+
+    // ID of this configuration, to be used in Set Configuration.
+    uint8_t configuration_value;
+
+    // Name of this configuration, as an index into string descriptors.
+    uint8_t configuration_name_index;
+
+    // True if this device is self-powered, and hence doesn't need the host to
+    // provide any power.
+    bool is_self_powered;
+
+    // True if this device supports the remote wakeup feature.
+    bool supports_remote_wakeup;
+
+    // Max current that may be drawn from the host by this device. Note the
+    // encoding is speed-specific.
+    uint8_t encoded_max_power;
+
+    // All descriptors for this configuration are returned, if allowed by
+    // buffer size. Further parsing has to be done on this extra data to learn
+    // about interfaces, endpoints, and other descriptors under this
+    // configuration.
+    std::vector<uint8_t> raw_data;
+  };
+
+  // Interface descriptor, retrieved from the device as a by-product of
+  // GetConfigurationDescriptor.
+  // The detailed meaning of each field can be found in the USB spec.
+  struct InterfaceDescriptor {
+    // ID of this interface, to be used in Set Interface.
+    uint8_t interface_number;
+
+    // ID of this alternate setting among several very similar and mutually
+    // exclusive interfaces.
+    uint8_t alternate_setting;
+
+    // Number of endpoints, other than the control endpoint, in this interface.
+    uint8_t num_endpoints;
+
+    // Class code as defined by the USB-IF.
+    uint8_t interface_class;
+
+    // Subclass code as defined by the USB-IF.
+    uint8_t interface_subclass;
+
+    // Protocol code as defined by the USB-IF.
+ uint8_t interface_protocol; + + // Name of the interface, as an index into string descriptors. + uint8_t interface_name_index; + }; + + UsbStandardCommands(std::unique_ptr<UsbDeviceInterface> device, + TimeoutMillis default_timeout_msec); + + // This class is neither copyable nor movable. + UsbStandardCommands(const UsbStandardCommands&) = delete; + UsbStandardCommands& operator=(const UsbStandardCommands&) = delete; + + virtual ~UsbStandardCommands(); + + util::Status Close(CloseAction action) { return device_->Close(action); } + + util::Status SetConfiguration(int configuration) { + return device_->SetConfiguration(configuration); + } + + util::Status ClaimInterface(int interface_number) { + return device_->ClaimInterface(interface_number); + } + + util::Status ReleaseInterface(int interface_number) { + return device_->ReleaseInterface(interface_number); + } + + util::Status GetDescriptor(UsbDeviceInterface::DescriptorType desc_type, + uint8_t desc_index, MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) { + // TODO add warnings/limitations on what can be queried, according + // to USB 3 spec. Only device, config, string, and BOS types can be queried. + // Only config and string types can have non-zero index specified. + // Some devices do respond to more types, but this is device-specific. + return device_->GetDescriptor(desc_type, desc_index, data_in, + num_bytes_transferred, context); + } + + DeviceSpeed GetDeviceSpeed() const { return device_->GetDeviceSpeed(); } + + util::Status SendControlCommand(const SetupPacket& command, + const char* context) { + return device_->SendControlCommand(command, default_timeout_msec_, context); + } + + util::Status SendControlCommandWithDataOut(const SetupPacket& command, + ConstBuffer data_out, + const char* context) { + return device_->SendControlCommandWithDataOut( + command, data_out, default_timeout_msec_, context); + } + + util::Status SendControlCommandWithDataIn(const SetupPacket& command, + MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) { + return device_->SendControlCommandWithDataIn( + command, data_in, num_bytes_transferred, default_timeout_msec_, + context); + } + + util::Status BulkOutTransfer(uint8_t endpoint, ConstBuffer data_out, + const char* context) { + return device_->BulkOutTransfer(endpoint, data_out, default_timeout_msec_, + context); + } + + util::Status BulkInTransfer(uint8_t endpoint, MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) { + return device_->BulkInTransfer(endpoint, data_in, num_bytes_transferred, + default_timeout_msec_, context); + } + + util::Status InterruptInTransfer(uint8_t endpoint, MutableBuffer data_in, + size_t* num_bytes_transferred, + const char* context) { + return device_->InterruptInTransfer(endpoint, data_in, + num_bytes_transferred, + default_timeout_msec_, context); + } + + util::Status AsyncBulkOutTransfer(uint8_t endpoint, ConstBuffer data_out, + DataOutDone callback, const char* context) { + return device_->AsyncBulkOutTransfer(endpoint, data_out, + default_timeout_msec_, + std::move(callback), context); + } + + util::Status AsyncBulkInTransfer(uint8_t endpoint, MutableBuffer data_in, + DataInDone callback, const char* context) { + return device_->AsyncBulkInTransfer( + endpoint, data_in, default_timeout_msec_, std::move(callback), context); + } + + util::Status AsyncInterruptInTransfer(uint8_t endpoint, MutableBuffer data_in, + DataInDone callback, + const char* context) { + return device_->AsyncInterruptInTransfer( +
endpoint, data_in, default_timeout_msec_, std::move(callback), context); + } + + void TryCancelAllTransfers() { device_->TryCancelAllTransfers(); } + + util::StatusOr<MutableBuffer> AllocateTransferBuffer(size_t buffer_size) { + return device_->AllocateTransferBuffer(buffer_size); + } + + util::Status ReleaseTransferBuffer(MutableBuffer buffer) { + return device_->ReleaseTransferBuffer(buffer); + } + + uint8_t ComposeUsbRequestType(CommandDataDir dir, CommandType type, + CommandRecipient recipient) { + return device_->ComposeUsbRequestType(dir, type, recipient); + } + + // Retrieves device descriptor from device. In some implementations, a cached + // one could be returned. + util::StatusOr<DeviceDescriptor> GetDeviceDescriptor(); + + // Retrieves configuration descriptor from device. In some implementations, a + // cached one could be returned. + util::StatusOr<ConfigurationDescriptor> GetConfigurationDescriptor( + uint8_t index, size_t max_extra_data_length); + + TimeoutMillis GetDefaultTimeoutMillis() const { + return default_timeout_msec_; + } + + private: + std::unique_ptr<UsbDeviceInterface> device_; + const TimeoutMillis default_timeout_msec_; +}; + +} // namespace driver +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_DRIVER_USB_USB_STANDARD_COMMANDS_H_ diff --git a/executable/BUILD b/executable/BUILD new file mode 100644 index 0000000..68f279e --- /dev/null +++ b/executable/BUILD @@ -0,0 +1,109 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Contains definitions for the Darwinn Executable format that serves as the +# compiler / runtime and firmware contract. + +# Keep the following two load calls separate, otherwise copybara cannot +# clean it up properly. +load( + "@flatbuffers//:build_defs.bzl", + "flatbuffer_cc_library", + "flatbuffer_py_library", +) + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "export_graph_fbs", + data = [ + "graph.fbs", + ], +) + +flatbuffer_cc_library( + name = "executable_fbs", + srcs = ["executable.fbs"], + flatc_args = [ + "--gen-object-api", # Adds 46KB to the generated code. + "--force-empty", + "--gen-mutable", # No size increase to the generated code. + ], +) + +# The extended version has a larger code size and is meant for tools/tests. +# Make sure it does not overlap with executable_fbs. +flatbuffer_cc_library( + name = "executable_fbs_extended", + srcs = ["executable.fbs"], + flatc_args = [ + "--gen-object-api", + "--force-empty", + "--reflect-names", + ], + out_prefix = "extended/", +) + +# Flatbuffer library for graph that would be interpreted by the runtime. +flatbuffer_cc_library( + name = "graph_fbs", + srcs = ["graph.fbs"], + flatc_args = [ + "--gen-object-api", + "--force-empty", + "--gen-mutable", + ], + include_paths = [ + ".", + "third_party/darwinn/executable", + ], +) + +# Flatbuffer library for graph that would be interpreted by the runtime. +# The extended version has a larger code size and is meant for tools/tests. +# Make sure it does not overlap with graph_fbs.
+flatbuffer_cc_library( + name = "graph_fbs_extended", + srcs = ["graph.fbs"], + flatc_args = [ + "--gen-object-api", + "--force-empty", + "--reflect-names", + ], + include_paths = [ + ".", + "third_party/darwinn/executable", + ], + out_prefix = "extended/", +) + +filegroup( + name = "graph_schema", + srcs = [ + "graph.fbs", + ], +) + +flatbuffer_cc_library( + name = "vii_commands_fbs", + srcs = ["vii_commands.fbs"], + flatc_args = [ + "--gen-object-api", + "--force-empty", + "--gen-mutable", + ], +) diff --git a/executable/executable.fbs b/executable/executable.fbs new file mode 100644 index 0000000..3ffe28b --- /dev/null +++ b/executable/executable.fbs @@ -0,0 +1,478 @@ +// IDL file for DarwiNN Executable. + +namespace platforms.darwinn; + + +// A new file identifier should only be introduced if a different schema, with +// probably a different root node, is needed. This shall be a very rare case. +file_identifier "DWN1"; + +enum Description : short { + // Bundle::Alu::MOVI instruction to load output activation base address. + BASE_ADDRESS_OUTPUT_ACTIVATION = 0, + + // Bundle::Alu::MOVI instruction to load input activation base address. + BASE_ADDRESS_INPUT_ACTIVATION = 1, + + // Bundle::Alu::MOVI instruction to load parameter base address. + BASE_ADDRESS_PARAMETER = 2, + + // Bundle::Alu::MOVI instruction to load scratch buffer base address. + BASE_ADDRESS_SCRATCH = 3, +} + +enum Position : short { + // Lower 32-bit of 64-bit address. + LOWER_32BIT = 0, + + // Upper 32-bit of 64-bit address. + UPPER_32BIT = 1, +} + +// Linker metadata. Enums for various special fields in the encoded instruction +// stream that will be populated by the driver at run time. +table Meta { + // Indicates which base address this metadata is targeting. + desc:Description; + + // For input/output/scratch, provides batch information. + // Parameter will not contain batch. + batch:int; + + // Name of the input/output layer for input/output activations. Parameter and + // scratch should not have this field. + name:string; + + // Tells which bit position to update. + position:Position; +} + +// Holds offset information of a field in an instruction bit stream chunk. +table FieldOffset { + // Linker metadata. + meta:Meta; + + // Bit offset. + offset_bit:int; +} + +// Holds information for an instruction bitstream chunk. +table InstructionBitstream { + // Encoded bitstream for a real hardware. + bitstream:[ubyte]; + + // Offset (in bits) of various fields in the instruction bit stream. These + // fields are filled in by the driver before sending the instruction stream + // to the hardware. + field_offsets:[FieldOffset]; +} + + +// Represents interrupt coming through descriptor path. +enum InterruptType : short { + // Scalar core supports 4 interrupts. + SCALAR_CORE_INT_0 = 0, + SCALAR_CORE_INT_1 = 1, + SCALAR_CORE_INT_2 = 2, + SCALAR_CORE_INT_3 = 3, +} + +// Represents direction of DMA. +enum Direction : short { + // From host to device. + INFEED = 0, + + // From device to host. + OUTFEED = 1, +} + +// Holds DMA hint information for DMA descriptors. +table DmaDescriptorHint { + // Metadata to indicate the DMA descriptor. + meta:Meta; + + // Since base address is determined at link time, byte offset from base + // address is recorded here. + offset_in_bytes:int; + + // Number of bytes to be transferred for this hint. + size_in_bytes:int; +} + +// Holds interrupt hint information. +table InterruptHint { + type:InterruptType; +} + +// Holds Instuction hint information. +table InstructionHint { + // Instruction chunk. 
Whole instruction chunk is always transferred. + instruction_chunk_index:int; +} + +// Holds fence hint. Fence enforces that all DMA hints before Fence should be +// processed completely before processing any DMA hints after the Fence. +table FenceHint { +} + +// A hint can be any one of the following. +union AnyHint { + DmaDescriptorHint, + InstructionHint, + InterruptHint, + FenceHint, +} + +// Hints deterministic DMA. +table DmaHint { + any_hint:AnyHint; + + // Direction of DMA. + direction:Direction; +} + +// A complete collection of DMA hints for either input or output. +table DmaHints { + // Series of hints. + hints:[DmaHint]; + + // True if "hints" cover all the DMAs in the model. + fully_deterministic:bool; +} + +// A group of simple int->int map that helps us to translate a user-visible +// coordinate value to hardware-friendly data layout for the final output +// activation. +// +// Note that this is needed only for 3D output. 1D output, this field will not +// be used and a user is not supposed to use this function. +// +// +// Let's use an example when we have 2x2 tiles and we want to produce 4x5x32 +// output tensor (y/x/z order). +// +// In this example, tile0 and tile2 will produce a 2x3x32 tensor and tile1 and +// tile3 will produce a 2x2x32 tensor. +// +// +--------+--------+ +// | Tile0 | Tile1 | +// | 2x3x32 | 2x2x32 | +// +--------+--------+ +// | Tile2 | Tile3 | +// | 2x3x32 | 2x2x32 | +// +--------+--------+ +// +// y_coordinate_to_linear_tile_id_map will be (0, 0, 2, 2), encoding the +// linearized tile ID of the first tile of a row that a target y value will be +// stored. +// +// x_coordinate_to_linear_tile_id_map will be (0, 0, 0, 1, 1), encoding the +// X tile ID of a tile that will hold corresponding x value. +// +// linearized_tile_byte_offset will be (0, 192, 320, 512) encoding the starting +// byte offset of output of each tile when we fully linearize output. +// +// x_coordinate_to_local_byte_offset will be (0, 32, 64, 0, 32) as byte +// offset, encoding byte offset for each local x offset. +// +// y_coordinate_to_local_y_offset will be (0, 1, 0, 1) as y offset for +// y=0 will be 0 in each tile while that for y=1 will be 1. +// +// x_coordinate_to_local_y_row_size will be (3*32, 3*32, 3*32, 2*32, 2*32) as +// each y-row for Tile0/2 is 3*32 bytes and that for Tile1/3 is 2*32 bytes. +table OutputLayout { + // Holds a map from a tensor Y coordinate value to the linearized ID of the + // first tile of rows that produces output values for a given Y coordinate. + y_coordinate_to_linear_tile_id_map:[int]; + + // Holds a map for a given x coordinate value to tile ID within a row of + // tiles. + x_coordinate_to_linear_tile_id_map:[int]; + + // Holds an accumulated offset value for each tile. + linearized_tile_byte_offset:[int]; + + // Holds a map from a tensor x coordinate to local byte offset within each + // tile. + x_coordinate_to_local_byte_offset:[int]; + + // Holds a map from a tensor y coordinate to local y offset within each tile. + y_coordinate_to_local_y_offset:[int]; + + // Holds a map from a tensor x coordinate to local y row size within each + // tile. + x_coordinate_to_local_y_row_size:[int]; +} + +// Inclusive range of numbers. +struct Range { + start:int; + end:int; +} + +// Tensor shape +table TensorShape { + // List of inclusive index range (start, end) of each dimension. + dimension:[Range]; +} + +// Tensor layout describes how tensor elements are stored in a linear memory +// space. See details in go/darwinn-output-layout. 
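// [Editorial aside, not part of the original change] One plausible way to read
// the OutputLayout maps documented above, sketched in C++ against the
// object-API type that --gen-object-api would produce; the function name is
// hypothetical. The byte offset of output element (y, x, z) is found by first
// picking the tile, then adding the local row and column offsets. With the
// example values above, element (y=1, x=1, z=0) resolves to
// 0 + 1*96 + 32 + 0 = 128, i.e. the second row, second column of Tile0.
//
//   int LinearByteOffset(const OutputLayoutT& layout, int y, int x, int z) {
//     const int tile = layout.y_coordinate_to_linear_tile_id_map[y] +
//                      layout.x_coordinate_to_linear_tile_id_map[x];
//     return layout.linearized_tile_byte_offset[tile] +
//            layout.y_coordinate_to_local_y_offset[y] *
//                layout.x_coordinate_to_local_y_row_size[x] +
//            layout.x_coordinate_to_local_byte_offset[x] + z;
//   }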
+table TensorLayout { + // Tensor shape stored in this layout. + shape:TensorShape; + + // Distance (in number of elements) between two adjacent elements in each + // dimension. + stride:[int]; +} + +// Represents output tensor shape of each tile. This information will be used +// for re-layout in the host. +table OutputShapeInfo { + // The final model output is transferred to the host in a list of tensor + // slices (sub-tensors). A slice is a collection of elements that can be + // represented as a single tensor shape and tensor layout. + slice_layout:[TensorLayout]; + + // Base offset (in bytes) of the first element in the layout. + slice_offset:[int]; +} + +// Numerics-related constant values needed for interpreting output tensor. +table NumericsConstants { + zero_point:int; + dequantization_factor:float; +} + +// //depot/google3/third_party/darwinn/api/runtime_version.h:runtime_version, +// //depot/google3/platforms/darwinn/driver/test_data/backward_compatibility/BUILD:test_cases) + +// Layer data type information. +// Note: The DataType enum should be synced with +// platforms/darwinn/model/config/array.proto. +enum DataType : short { + // Unsigned fixed point (it would be more appropriate to call this an affine + // value) means there is a scale and zero point associated with this tensor, + // To transform unsigned fixed-point values to real values: + // real_value = (unsigned_fixed-point_value - zero_point) * scale + FIXED_POINT8 = 0, + FIXED_POINT16 = 1, + // SIGNED_FIXED_POINT32 is a signed fixed point but is given an enum value + // of 2 due to historical reason. Please see the below for documentation of + // signed fixed-point types. + SIGNED_FIXED_POINT32 = 2, + + // BFLOAT is Google’s own floating point format, with 8 bit exponent and 8 bit + // significand (7 bit stored significand). + BFLOAT = 3, + // HALF is industry standard IEEE 754-2008 binary16, with 5 bit exponent and + // 11 bit significand (10 bit stored significand). + HALF = 4, + // SINGLE is industry standard IEEE 754-2008 binary32, with 8 bit exponent and + // 24 bit significant (23 bit stored signficand). + SINGLE = 5, + + // Signed fixed point data types. Number is stored in two's complement format. + // There is an associated scale but no zero point. To transform fixed-point + // values to real values: + // real_value = signed_fixedpoint_value * scale + SIGNED_FIXED_POINT8 = 8, + SIGNED_FIXED_POINT16 = 9, +} + +// //depot/google3/third_party/darwinn/api/runtime_version.h:runtime_version, +// //depot/google3/platforms/darwinn/driver/test_data/backward_compatibility/BUILD:test_cases) + +// Output layer specific information. +table OutputLayer { + // Encapsulates information needed to transform a multi-dimensional output + // tensor to its original YXZ layout. This field must be set for any tensor + // with x_dim and y_dim more than 1. + layout:OutputLayout; + data_type:DataType; // deprecated + + // Output shape information that is streamed from the tiles. + shape_info:OutputShapeInfo; +} + +// Input layer specific information. +table InputLayer { +} + +// One of output or input layer. +union AnyLayer { + OutputLayer, + InputLayer, +} + +// Layer information. +table Layer { + // Name of the corresponding input/output layer. + name:string; + + // Size in bytes, including padding. This number is for batch_size=1. The + // unpadded byte size of a tensor is: + // x_dim * y_dim * z_dim * bytes_per_data_type. + size_bytes:int; + + // Dimension info. All these fields should be set for input and output + // tensors. 
?_dim=1 means we don't have ? dimension. For example, in a single + // dimensional tensor x_dim=1, y_dim=1, z_dim=N. + y_dim:int; + x_dim:int; + z_dim:int; + + // Numerics constants used for dequantization and quantization. + numerics:NumericsConstants; + + // For input layer, this is the data type of input, for output layer, this is the data type of output. + data_type:DataType; + + // Input or Output Layer specific information. + any_layer:AnyLayer; + + // How many times this layer will get executed per inference. Default is 1. + // This information will be used to create large enough buffer to host inputs + // and outputs for layers that will get executed several times per inference. + execution_count_per_inference:int = 1; + + // If set, the activations on this layer will be cached on TPU DRAM (if DRAM + // is available and there is enough free space on it). + cache_on_dram:bool = false; + + // Tensor shape info. + shape:TensorShape; +} + +// Specifies the nature of an executable. +enum ExecutableType : short { + // Everything needed to run a successful inference is included. + STAND_ALONE = 0, + + // Only loads parameters into TPU memory. This type of executable should + // always accompany at least 1 EXECUTION_ONLY executable in the same package. + PARAMETER_CACHING = 1, + + // This type of executable assumes the parameters are already cached on TPU. + // This type should always be accompanied by a PARAMETER_CACHING executable in + // the same package. + EXECUTION_ONLY = 2, +} + +table Executable { + // Executable format version. Set to 0 for now. + version:int = 0; + + // Model name. + name:string; + + // Model protobuf in binary serialized format. + serialized_model:[ubyte]; + + // Batch size. That is the number of inputs that can be simultaneously + // processed. + batch_size:int; + + // Size in bytes of the scratch buffer expected for this model. + // This number is for batch_size=1. + scratch_size_bytes:int; + + // Encoded instruction bitstreams. + instruction_bitstreams:[InstructionBitstream]; + + // Parameter stream. This field must be guaranteed to be aligned by the code + // that produces the flat buffer. As of now, executable_converter ensures + // this. + parameters:[ubyte]; + + // Dma Hints. + dma_hints:DmaHints; + + // Input layer Information + input_layers:[Layer]; + + // Output layer Information. + output_layers:[Layer]; + + // Chip that the executable was compiled for. + chip:string; + + // Deprecated. Use estimated_cycles_64bit below instead. + estimated_cycles:int; + + // The maximum amount of narrow memory bytes that is guaranteed to be used per + // tile. All narrow memory used in a tile is guaranteed to be at byte + // addresses below this value. + used_narrow_memory_bytes_per_tile:int; + + // Type of this executable. If not specified, runtime assumes STAND_ALONE. + type:ExecutableType; + + // Parameter-caching executables with the same token can cache their + // parameters together on the TPU SRAM. + parameter_caching_token:uint64; + + // If set, parameters in this model will be loaded in the TPU DRAM for higher + // performance. TPU DRAM is available on some architectures (e.g. Noronha). + // TPU DRAM is a scarce resource, therefore only selected models can have this + // option enabled (e.g. RNN-T for Noronha). If this option is enabled and + // enough TPU DRAM is not available an error is returned at run time. + use_tpu_dram_for_parameters:bool = false; + + // Estimated runtime in cycles for this model. 
+ estimated_cycles_64bit:int64; +} + +// MultiExecutable encapsulates one or more DarwiNN serialized executables that +// are all part of the same package. +table MultiExecutable { + serialized_executables:[string]; +} + +// Serialized package allows individual packages to stay page-aligned +// relative to beginning of the byte array. +table SerializedPackage { + serialized_package:[ubyte] (nested_flatbuffer: "Package"); +} + +// The collection of executables, signature and everything else that is needed +// for DarwiNN runtime to run one or more models that are related. +table Package { + // Minimum runtime version needed to process this package correctly. + min_runtime_version:int; + + // A serialized MultiExecutable. + serialized_multi_executable:[ubyte]; + + // Signature of serialized_multi_executable. + signature:[ubyte]; + + // The version of this package to identify assumptions on the structure. + keypair_version:int; + + // Specifies the version of DarwiNN compiler used to create this package. + compiler_version:string; + + // Chip ID in the virtual cluster to execute these graphs. + // 0 if this package is compiled to run on a single chip. + // -1 if this is a multiple-chip package. + virtual_chip_id:int = 0; + + // Package data for individual chip to execute. + // Note that the package data is not aligned in package bundle file, but it + // will be loaded into aligned memory block at model registration. + // An intermediate table SerializedPackage is needed, for flatbuffer only + // supports 1-d vector. + // TODO: Consider creating a new root type for new chips. + multi_chip_package:[SerializedPackage]; + + // A user-specified identifier. This is for limited use of offline compiled + // models. + model_identifier:string; +} + +root_type Package; + +// //depot/google3/third_party/darwinn/api/runtime_version.h:runtime_version, +// //depot/google3/platforms/darwinn/driver/test_data/backward_compatibility/BUILD:test_cases) diff --git a/port/BUILD b/port/BUILD new file mode 100644 index 0000000..1bceb41 --- /dev/null +++ b/port/BUILD @@ -0,0 +1,434 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Port of various google3 libraries and utilities. In google3, blaze allows +# building two variants controlled by the --define=darwinn_portable flag. +# --define=darwinn_portable=0 (default) : +# Port directory points to google3. (DARWINN_PORT_GOOGLE3 is defined.) +# --define=darwinn_portable=1 : +# Port directory points to default/ (DARWINN_PORT_DEFAULT is defined.) +# Android: +# For compiling against Android NDK (with google3 blaze) both above configurations +# will work. +# For compiling as part of Android system however, DARWINN_PORT_ANDROID +# must be defined as part of the Android build system. + +load("//:build_defs.bzl", "darwinn_port_defines") + +package(default_visibility = ["//visibility:public"]) + +# All Google Owned Code except : +# - certain files in port/default/ that are under Apache 2.0 license. 
+licenses(["notice"]) + +cc_library( + name = "port", + hdrs = [ + "aligned_malloc.h", + "array_slice.h", + "builddata.h", + "casts.h", + "cleanup.h", + "defs.h", + "errors.h", + "gflags.h", + "integral_types.h", + "logging.h", + "macros.h", + "math_util.h", + "openssl.h", + "ptr_util.h", + "status.h", + "status_macros.h", + "statusor.h", + "stringprintf.h", + "time.h", + "unreachable.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_portable": [ + "//port/default:port", + ], + "//conditions:default": [ + "//base", + "//util/gtl:ptr_util", + "//util/math:mathutil", + "//util/task:status", + "//util/task:statusor", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], + }) + [ + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:flag", + "//port/default:cleanup", + "//port/default:unreachable", + ], +) + +cc_library( + name = "unreachable", + hdrs = [ + "unreachable.h", + ], + defines = darwinn_port_defines(), + deps = [ + "//port/default:unreachable", + ], +) + +cc_library( + name = "macros", + hdrs = [ + "defs.h", + "macros.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_portable": [ + "//port/default:port", + ], + "//:darwinn_firmware": [ + "//port/default:port", + ], + "//conditions:default": [ + "//base", + ], + }), +) + +cc_library( + name = "logging", + hdrs = [ + "defs.h", + "logging.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_portable": [ + "//port/default:port", + ], + "//:darwinn_firmware": [ + "//port/firmware:logging", + ], + "//conditions:default": [ + "//base", + ], + }), +) + +# This library is used for Darwinn 1.0 runtime only (Android/Google3). +cc_library( + name = "thread_annotations", + hdrs = [ + "defs.h", + "thread_annotations.h", + ], + defines = darwinn_port_defines(), + deps = [ + "//port/default:thread_annotations", + ], +) + +cc_library( + name = "integral_types", + hdrs = [ + "defs.h", + "integral_types.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_portable": [ + "//port/default:port", + ], + "//:darwinn_firmware": [ + "//port/default:port", + ], + "//conditions:default": [ + "//base", + ], + }), +) + +cc_library( + name = "std_mutex_lock", + hdrs = ["std_mutex_lock.h"], + deps = [ + ":port", + "//port:thread_annotations", + ], +) + +cc_library( + name = "shared_mutex", + srcs = [ + "shared_mutex.cc", + ], + hdrs = ["shared_mutex.h"], + deps = [ + "port", + "//port:thread_annotations", + ], +) + +cc_library( + name = "mutex", + hdrs = [ + "defs.h", + "mutex.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_portable": [ + "//port/default:port", + "//port:thread_annotations", + ], + "//:darwinn_firmware": [ + "//third_party/safertos_addons", + ], + "//conditions:default": [ + "@com_google_absl//absl/synchronization", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "dma330", + hdrs = [ + "defs.h", + "dma330.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": [ + "//firmware/driver/dma:dma330", + ], + "//conditions:default": [ + "//port/default:dma330", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "semaphore", + hdrs = [ + "defs.h", + "semaphore.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": [ + "//third_party/safertos_addons", + ], + "//conditions:default": [ + "//port/default:port", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "condition_variable", + hdrs = [ + 
"condition_variable.h", + "defs.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": [ + "//firmware/common:condition_variable", + ], + "//conditions:default": [ + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "restartable_thread", + hdrs = ["restartable_thread.h"], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": ["//port/firmware:restartable_thread"], + "//conditions:default": ["//port/default:restartable_thread"], + }), + alwayslink = 1, +) + +cc_library( + name = "blocking_queue", + hdrs = ["blocking_queue.h"], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": ["//port/firmware:blocking_queue"], + "//conditions:default": ["//port/default:blocking_queue"], + }), + alwayslink = 1, +) + +cc_library( + name = "lock_guard", + srcs = ["lock_guard.cc"], + hdrs = ["lock_guard.h"], + defines = darwinn_port_defines(), + deps = [ + "//port:mutex", + "//port:thread_annotations_firmware", + ], + alwayslink = 1, +) + +cc_library( + name = "blocking_counter", + srcs = ["blocking_counter.cc"], + hdrs = ["blocking_counter.h"], + deps = [ + ":port", + ":std_mutex_lock", + "//port:thread_annotations", + ], +) + +# If --define darwinn_xprof_enabled=1, links with xprof libraries. +config_setting( + name = "darwinn_xprof_enabled", + values = { + "define": "darwinn_xprof_enabled=1", + }, +) + +# If --define darwinn_android_google3_trace_enabled=1, enable traces for +# android binaries built in google3. +config_setting( + name = "darwinn_android_google3_trace_enabled", + values = { + "define": "darwinn_android_google3_trace_enabled=1", + }, +) + +# If --define darwinn_csv_trace_enabled=1, generates trace outputs in csv. +config_setting( + name = "darwinn_csv_trace_enabled", + values = { + "define": "darwinn_csv_trace_enabled=1", + }, +) + +# If --define darwinn_perfetto_trace_enabled=1, generates trace outputs in perfetto proto file. +config_setting( + name = "darwinn_perfetto_trace_enabled", + values = { + "define": "darwinn_perfetto_trace_enabled=1", + }, +) + +# If --define pnp_benchmarking=1, control what trace points are generated to minimize intrusion. 
+# (used by Silicon Software PnP team) +config_setting( + name = "pnp_benchmarking", + values = { + "define": "pnp_benchmarking=1", + }, +) + +cc_library( + name = "tracing", + hdrs = ["tracing.h"], + defines = darwinn_port_defines(), +) + +cc_library( + name = "string_util", + hdrs = ["string_util.h"], + deps = [ + "//port/default:strcat", + ], +) + +cc_library( + name = "timer", + srcs = select({ + "//:windows": ["timer_windows.cc"], + "//:darwin": ["timer_darwin.cc"], + "//conditions:default": ["timer_linux.cc"], + }), + hdrs = ["timer.h"] + select({ + "//:windows": ["timer_windows.h"], + "//:darwin": ["timer_darwin.h"], + "//conditions:default": ["timer_linux.h"], + }), + deps = [ + ":port", + ], +) + +cc_library( + name = "posix_time", + srcs = ["posix_time.cc"], + hdrs = ["posix_time.h"], + deps = [ + ":integral_types", + ], +) + +cc_library( + name = "cpu", + hdrs = ["cpu.h"], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": ["//port/firmware:cpu"], + "//conditions:default": ["//port/default:cpu"], + }), +) + +cc_library( + name = "demangle", + hdrs = [ + "defs.h", + "demangle.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//conditions:default": [ + "//base", + ], + "//:darwinn_portable": [ + ], + "//:darwinn_firmware": [ + ], + }), +) + +cc_library( + name = "bsp", + hdrs = [ + "bsp.h", + "defs.h", + ], + defines = darwinn_port_defines(), + deps = select({ + "//:darwinn_firmware": [ + "//firmware:bsp", + ], + "//conditions:default": [ + ], + }), + alwayslink = 1, +) diff --git a/port/aligned_malloc.h b/port/aligned_malloc.h new file mode 100644 index 0000000..4d7e4c0 --- /dev/null +++ b/port/aligned_malloc.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_ALIGNED_MALLOC_H_ +#define DARWINN_PORT_ALIGNED_MALLOC_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "base/port.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/aligned_malloc.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_ALIGNED_MALLOC_H_ diff --git a/port/array_slice.h b/port/array_slice.h new file mode 100644 index 0000000..39c2547 --- /dev/null +++ b/port/array_slice.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_ARRAY_SLICE_H_ +#define DARWINN_PORT_ARRAY_SLICE_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "absl/types/span.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/array_slice.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_ARRAY_SLICE_H_ diff --git a/port/blocking_counter.cc b/port/blocking_counter.cc new file mode 100644 index 0000000..8cc276b --- /dev/null +++ b/port/blocking_counter.cc @@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/blocking_counter.h" + +#include <condition_variable> // NOLINT +#include <mutex> // NOLINT + +#include "port/logging.h" +#include "port/std_mutex_lock.h" + +namespace platforms { +namespace darwinn { + +bool BlockingCounter::DecrementCount() { + StdMutexLock lock(&mutex_); + count_--; + if (count_ < 0) { + LOG(FATAL) << "BlockingCounter::DecrementCount() called too many times."; + } + + if (count_ == 0) { + cv_.notify_all(); + return true; + } + return false; +} + +void BlockingCounter::Wait() { + StdCondMutexLock lock(&mutex_); + cv_.wait(lock, [this] { return count_ == 0; }); +} + +} // namespace darwinn +} // namespace platforms diff --git a/port/blocking_counter.h b/port/blocking_counter.h new file mode 100644 index 0000000..73c2973 --- /dev/null +++ b/port/blocking_counter.h @@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_BLOCKING_COUNTER_H_ +#define DARWINN_PORT_BLOCKING_COUNTER_H_ + +#include <condition_variable> // NOLINT +#include <mutex> // NOLINT + +#include "port/std_mutex_lock.h" +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { + +// This class allows a thread to block for a pre-specified number of actions. +// Based on absl implementation, except that std::mutex is used as opposed to +// absl::Mutex. +// TODO: Remove this implementation when we fully migrate to absl. +// See: cs/absl/synchronization/blocking_counter.h +class BlockingCounter { + public: + explicit BlockingCounter(int initial_count) : count_(initial_count) {} + + BlockingCounter(const BlockingCounter&) = delete; + BlockingCounter& operator=(const BlockingCounter&) = delete; + + // BlockingCounter::DecrementCount() + // + // Decrements the counter's "count" by one, and returns "count == 0". This + // function requires that "count != 0" when it is called.
+ // + // Memory ordering: For any threads X and Y, any action taken by X + // before it calls `DecrementCount()` is visible to thread Y after + // Y's call to `DecrementCount()`, provided Y's call returns `true`. + bool DecrementCount(); + + // BlockingCounter::Wait() + // + // Blocks until the counter reaches zero. This function may be called at most + // once. On return, `DecrementCount()` will have been called "initial_count" + // times and the blocking counter may be destroyed. + // + // Memory ordering: For any threads X and Y, any action taken by X + // before X calls `DecrementCount()` is visible to Y after Y returns + // from `Wait()`. + void Wait(); + + private: + std::mutex mutex_; + std::condition_variable cv_; + int count_; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_BLOCKING_COUNTER_H_ diff --git a/port/builddata.h b/port/builddata.h new file mode 100644 index 0000000..d8a57b2 --- /dev/null +++ b/port/builddata.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_BUILDDATA_H_ +#define DARWINN_PORT_BUILDDATA_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "absl/base/builddata.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/builddata.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_BUILDDATA_H_ diff --git a/port/casts.h b/port/casts.h new file mode 100644 index 0000000..3beadba --- /dev/null +++ b/port/casts.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_CASTS_H_ +#define DARWINN_PORT_CASTS_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "base/casts.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/casts.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_CASTS_H_ diff --git a/port/cleanup.h b/port/cleanup.h new file mode 100644 index 0000000..27eece4 --- /dev/null +++ b/port/cleanup.h @@ -0,0 +1,22 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_CLEANUP_H_ +#define DARWINN_PORT_CLEANUP_H_ + +#include "port/defs.h" + +#include "port/default/cleanup.h" + +#endif // DARWINN_PORT_CLEANUP_H_ diff --git a/port/default/BUILD b/port/default/BUILD new file mode 100644 index 0000000..2987cec --- /dev/null +++ b/port/default/BUILD @@ -0,0 +1,148 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Port of various google3 libraries and utilities. + +package(default_visibility = ["//visibility:public"]) + +# All Google Owned Code except : +# - certain files in port/default/ that are under Apache 2.0 license. +licenses(["notice"]) + +# Independent headers that can be included by port_from_tf. +cc_library( + name = "port_base", + hdrs = [ + "error_codes.h", + "strcat.h", + "unreachable.h", + ] + select({ + "//:windows": ["unreachable_windows.h"], + "//conditions:default": ["unreachable_default.h"], + }), +) + +# Port of various google3 libraries and utilities. +cc_library( + name = "port", + srcs = [ + "status_macros.cc", + "stringprintf.cc", + ], + hdrs = [ + "aligned_malloc.h", + "array_slice.h", + "builddata.h", + "casts.h", + "error_codes.h", + "errors.h", + "integral_types.h", + "logging.h", + "macros.h", + "mutex.h", + "math_util.h", + "ptr_util.h", + "semaphore.h", + "status.h", + "status_macros.h", + "statusor.h", + "strcat.h", + "stringprintf.h", + ] + select({ + "//:windows": ["aligned_malloc_windows.h"], + "//conditions:default": ["aligned_malloc_default.h"], + }), + deps = [ + ":thread_annotations", + "//port/default/port_from_tf", + ], +) + +cc_library( + name = "thread_annotations", + hdrs = [ + "thread_annotations.h", + ], + deps = [ + "//port/default/port_from_tf:thread_annotations", + ], +) + +cc_library( + name = "strcat", + hdrs = [ + "strcat.h", + ], +) + +cc_library( + name = "unreachable", + hdrs = [ + "unreachable.h", + ] + select({ + "//:windows": ["unreachable_windows.h"], + "//conditions:default": ["unreachable_default.h"], + }), +) + +cc_library( + name = "cleanup", + hdrs = ["cleanup.h"], + deps = [ + ":port", + ], +) + +cc_library( + name = "cpu", + hdrs = ["cpu.h"], +) + +cc_library( + name = "restartable_thread", + hdrs = ["restartable_thread.h"], + deps = [ + "//firmware/common:log", + "//firmware/common/callback", + "//firmware/os:restartable_thread_interface", + ], +) + +cc_library( + name = "blocking_queue", + hdrs = ["blocking_queue.h"], + deps = [ + "//firmware/common:log", + "//firmware/datastruct:circular_queue", + "//firmware/os:queue_interface", + ], +) + +cc_library( + name = "memory_barriers", + hdrs = ["memory_barriers.h"], +) + +cc_library( + name = "dma330", + hdrs = ["dma330.h"], + deps = [ + ":port", + "//firmware/common:iomem", + "//firmware/common:status", + "//firmware/common/callback", + "//firmware/driver/dma:dma330_interface", + ], +) diff --git a/port/default/aligned_malloc.h b/port/default/aligned_malloc.h new file 
mode 100644 index 0000000..4e01196 --- /dev/null +++ b/port/default/aligned_malloc.h @@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (C) 1999 and onwards Google, Inc. +// +// Various portability macros, type definitions, and inline functions +// This file is used for both C and C++! +// +// These are weird things we need to do to get this compiling on +// random systems (and on SWIG). +// +// MOE:begin_strip +// This file is open source. You may export it with your open source projects +// as long as you use MOE to strip proprietary comments. +// MOE:end_strip + +#ifndef DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_H_ +#define DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_H_ + +#if defined(_WIN32) +#include "port/default/aligned_malloc_windows.h" +#else +#include "port/default/aligned_malloc_default.h" +#endif + +#endif // DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_H_ diff --git a/port/default/aligned_malloc_default.h b/port/default/aligned_malloc_default.h new file mode 100644 index 0000000..67e36e9 --- /dev/null +++ b/port/default/aligned_malloc_default.h @@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_DEFAULT_H_ +#define DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_DEFAULT_H_ + +#include <stdlib.h> + +namespace platforms { +namespace darwinn { + +inline void *aligned_malloc(size_t size, int minimum_alignment) { + void* ptr; + if (posix_memalign(&ptr, minimum_alignment, size) == 0) + return ptr; + return nullptr; +} + +inline void aligned_free(void *aligned_memory) { + free(aligned_memory); +} + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_DEFAULT_H_ diff --git a/port/default/aligned_malloc_windows.h b/port/default/aligned_malloc_windows.h new file mode 100644 index 0000000..ceaed4a --- /dev/null +++ b/port/default/aligned_malloc_windows.h @@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_WINDOWS_H_ +#define DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_WINDOWS_H_ + +#include <malloc.h> // for _aligned_{malloc/free}() + +namespace platforms { +namespace darwinn { + +inline void *aligned_malloc(size_t size, int minimum_alignment) { + return _aligned_malloc(size, minimum_alignment); +} + +inline void aligned_free(void *aligned_memory) { + _aligned_free(aligned_memory); +} + +} // namespace darwinn +} // namespace platforms + + +#endif // DARWINN_PORT_DEFAULT_ALIGNED_MALLOC_WINDOWS_H_ diff --git a/port/default/array_slice.h b/port/default/array_slice.h new file mode 100644 index 0000000..656ac8f --- /dev/null +++ b/port/default/array_slice.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_ARRAY_SLICE_H_ +#define DARWINN_PORT_DEFAULT_ARRAY_SLICE_H_ + +#include "port/default/port_from_tf/array_slice.h" + +#endif // DARWINN_PORT_DEFAULT_ARRAY_SLICE_H_ diff --git a/port/default/builddata.h b/port/default/builddata.h new file mode 100644 index 0000000..335b17c --- /dev/null +++ b/port/default/builddata.h @@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_BUILDDATA_H_ +#define DARWINN_PORT_DEFAULT_BUILDDATA_H_ + +#if defined(_WIN32) +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) +#define COMPILER_VERSION "MSVC " TOSTRING(_MSC_FULL_VER) +#elif defined(__GNUC__) +#define COMPILER_VERSION __VERSION__ +#else +#define COMPILER_VERSION "Unknown" +#endif + +struct BuildData { + static const char* BuildLabel() { + return "COMPILER=" COMPILER_VERSION ",DATE=" __DATE__ ",TIME=" __TIME__ ",CL_NUMBER=315372822"; + } +}; + +#endif // DARWINN_PORT_DEFAULT_BUILDDATA_H_ diff --git a/port/default/casts.h b/port/default/casts.h new file mode 100644 index 0000000..050e5ec --- /dev/null +++ b/port/default/casts.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_CASTS_H_ +#define DARWINN_PORT_DEFAULT_CASTS_H_ + +#include "port/default/port_from_tf/casts.h" + +#endif // DARWINN_PORT_DEFAULT_CASTS_H_ diff --git a/port/default/cleanup.h b/port/default/cleanup.h new file mode 100644 index 0000000..057f73d --- /dev/null +++ b/port/default/cleanup.h @@ -0,0 +1,117 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its +// destructor. The easiest way to use MakeCleanup is with a lambda argument, +// capturing the return value in an 'auto' local variable. Most users will not +// need more sophisticated syntax than that. +// +// Example: +// void func() { +// FILE* fp = fopen("data.txt", "r"); +// if (fp == nullptr) return; +// auto fp_cleaner = absl::MakeCleanup([fp] { fclose(fp); }); +// // No matter what, fclose(fp) will happen. +// DataObject d; +// while (ReadDataObject(fp, &d)) { +// if (d.IsBad()) { +// LOG(ERROR) << "Bad Data"; +// return; +// } +// PushGoodData(d); +// } +// } +// +// You can use Cleanup<F> directly, instead of using MakeCleanup and auto, +// but there's rarely a reason to do that. +// +// You can call 'release()' on a Cleanup object to cancel the cleanup. +// +// Style exception for rvalue references: CL/62979921. +// No NOLINT qualifiers, since cpplint's handling of && is imperfect. + +#ifndef DARWINN_PORT_DEFAULT_CLEANUP_H_ +#define DARWINN_PORT_DEFAULT_CLEANUP_H_ + +#include <type_traits> +#include <utility> + +#include "port/default/macros.h" + +namespace platforms { +namespace darwinn { + +// A move-only RAII object that calls a stored cleanup functor when destroyed. +// Cleanup<F> is the return type of absl::MakeCleanup(F). +template <typename F> +class Cleanup { + public: + Cleanup() + : released_(true), f_() {} + + template <typename G> + explicit Cleanup(G&& f) // NOLINT + : f_(std::forward<G>(f)) {} // NOLINT(build/c++11) + + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) { } + + // Implicitly move-constructible from any compatible Cleanup<G>. + // The source will be released as if src.release() were called. + // A moved-from Cleanup can be safely destroyed or reassigned. + template <typename G> + Cleanup(Cleanup<G>&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) { } + + // Assignment to a Cleanup object behaves like destroying it and making a new + // one in its place, analogous to unique_ptr semantics.
+ Cleanup& operator=(Cleanup&& src) { // NOLINT + if (!released_) f_(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) f_(); + } + + // Releases the cleanup function instead of running it. + // Hint: use c.release()() to run early. + F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference<F>::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template <int&... ExplicitParameterBarrier, typename F, + typename DecayF = typename std::decay<F>::type> +ABSL_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) { + static_assert(sizeof...(ExplicitParameterBarrier) == 0, + "No explicit template arguments."); + return Cleanup<DecayF>(std::forward<F>(f)); +} + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_CLEANUP_H_ diff --git a/port/default/error_codes.h b/port/default/error_codes.h new file mode 100644 index 0000000..55b4e3e --- /dev/null +++ b/port/default/error_codes.h @@ -0,0 +1,146 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_ERROR_CODES_H_ +#define DARWINN_PORT_DEFAULT_ERROR_CODES_H_ + +namespace platforms { +namespace darwinn { +namespace util { +namespace error { + +// The canonical error codes. +enum Code { + // Not an error; returned on success + OK = 0, + + // The operation was cancelled (typically by the caller). + CANCELLED = 1, + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + UNKNOWN = 2, + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3, + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4, + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5, + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6, + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors).
PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7, + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16, + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8, + + // Operation was rejected because the system is not in a state + // required for the operation's execution. For example, directory + // to be deleted may be non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: + // (a) Use UNAVAILABLE if the client can retry just the failing call. + // (b) Use ABORTED if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FAILED_PRECONDITION if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FAILED_PRECONDITION + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. + // (d) Use FAILED_PRECONDITION if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FAILED_PRECONDITION = 9, + + // The operation was aborted, typically due to a concurrency issue + // like sequencer check failures, transaction aborts, etc. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + ABORTED = 10, + + // Operation tried to iterate past the valid input range. E.g., seeking or + // reading past end of file. + // + // Unlike INVALID_ARGUMENT, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate INVALID_ARGUMENT if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OUT_OF_RANGE if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between FAILED_PRECONDITION and + // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OUT_OF_RANGE error to detect when + // they are done. + OUT_OF_RANGE = 11, + + // Operation is not implemented or not supported/enabled in this service. + UNIMPLEMENTED = 12, + + // Internal errors. Means some invariants expected by underlying + // system has been broken. If you see one of these errors, + // something is very broken. + INTERNAL = 13, + + // The service is currently unavailable. This is a most likely a + // transient condition and may be corrected by retrying with + // a backoff. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + UNAVAILABLE = 14, + + // Unrecoverable data loss or corruption. 
+ DATA_LOSS = 15, +}; + +} // namespace error +} // namespace util +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_ERROR_CODES_H_ diff --git a/port/default/errors.h b/port/default/errors.h new file mode 100644 index 0000000..5775205 --- /dev/null +++ b/port/default/errors.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_ERRORS_H_ +#define DARWINN_PORT_DEFAULT_ERRORS_H_ + +#include "port/default/port_from_tf/errors.h" + +#endif // DARWINN_PORT_DEFAULT_ERRORS_H_ diff --git a/port/default/integral_types.h b/port/default/integral_types.h new file mode 100644 index 0000000..ad0a525 --- /dev/null +++ b/port/default/integral_types.h @@ -0,0 +1,90 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2010 Google Inc. All Rights Reserved. +// +// Basic integer type definitions for various platforms +// +// This code is compiled directly on many platforms, including client +// platforms like Windows, Mac, and embedded systems. Before making +// any changes here, make sure that you're not breaking any platforms. + +#ifndef DARWINN_PORT_DEFAULT_INTEGRAL_TYPES_H_ +#define DARWINN_PORT_DEFAULT_INTEGRAL_TYPES_H_ + +namespace platforms { +namespace darwinn { + +// Standard typedefs +// Signed integer types with width of exactly 8, 16, 32, or 64 bits +// respectively, for use when exact sizes are required. +typedef signed char schar; +typedef signed char int8; +typedef short int16; +typedef int int32; +typedef long long int64; + +// NOTE: unsigned types are DANGEROUS in loops and other arithmetical +// places. Use the signed types unless your variable represents a bit +// pattern (eg a hash value) or you really need the extra bit. Do NOT +// use 'unsigned' to express "this value should always be positive"; +// use assertions for this. + +// Unsigned integer types with width of exactly 8, 16, 32, or 64 bits +// respectively, for use when exact sizes are required. +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long long uint64; + +// A type to represent a Unicode code-point value. As of Unicode 4.0, +// such values require up to 21 bits. +// (For type-checking on pointers, make this explicitly signed, +// and it should always be the signed version of whatever int32 is.) +typedef signed int char32; + +// A type to represent a natural machine word (for e.g. 
efficiently +// scanning through memory for checksums or index searching). Don't use +// this for storing normal integers. Ideally this would be just +// unsigned int, but our 64-bit architectures use the LP64 model +// (http://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models), hence +// their ints are only 32 bits. We want to use the same fundamental +// type on all archs if possible to preserve *printf() compatability. +typedef unsigned long uword_t; + +#define GG_LONGLONG(x) x##LL +#define GG_ULONGLONG(x) x##ULL +#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. +#define GG_LL_FORMAT_W L"ll" + +static const uint8 kuint8max = static_cast(0xFF); +static const uint16 kuint16max = static_cast(0xFFFF); +static const uint32 kuint32max = static_cast(0xFFFFFFFF); +static const uint64 kuint64max = + static_cast(GG_LONGLONG(0xFFFFFFFFFFFFFFFF)); +static const int8 kint8min = static_cast(~0x7F); +static const int8 kint8max = static_cast(0x7F); +static const int16 kint16min = static_cast(~0x7FFF); +static const int16 kint16max = static_cast(0x7FFF); +static const int32 kint32min = static_cast(~0x7FFFFFFF); +static const int32 kint32max = static_cast(0x7FFFFFFF); +static const int64 kint64min = + static_cast(GG_LONGLONG(~0x7FFFFFFFFFFFFFFF)); +static const int64 kint64max = + static_cast(GG_LONGLONG(0x7FFFFFFFFFFFFFFF)); + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_INTEGRAL_TYPES_H_ diff --git a/port/default/logging.h b/port/default/logging.h new file mode 100644 index 0000000..62179e8 --- /dev/null +++ b/port/default/logging.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_LOGGING_H_ +#define DARWINN_PORT_DEFAULT_LOGGING_H_ + +#include "port/default/port_from_tf/logging.h" + +#endif // DARWINN_PORT_DEFAULT_LOGGING_H_ diff --git a/port/default/macros.h b/port/default/macros.h new file mode 100644 index 0000000..3e28417 --- /dev/null +++ b/port/default/macros.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
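A minimal usage sketch for the fixed-width typedefs, limit constants, and GG_LONGLONG macro declared in port/default/integral_types.h above. This is an editorial illustration rather than part of the patch; the ClampToInt32 helper and kMaxPayloadBytes constant are hypothetical names.

#include "port/default/integral_types.h"

using platforms::darwinn::int32;
using platforms::darwinn::int64;
using platforms::darwinn::kint32max;
using platforms::darwinn::kint32min;

// Hypothetical helper: clamp a 64-bit offset into the int32 range.
int32 ClampToInt32(int64 offset) {
  if (offset > kint32max) return kint32max;
  if (offset < kint32min) return kint32min;
  return static_cast<int32>(offset);
}

// GG_LONGLONG attaches the LL suffix so the literal stays 64-bit even on
// 32-bit targets (1 TiB expressed as a 64-bit constant).
const int64 kMaxPayloadBytes = GG_LONGLONG(1) << 40;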
+ +#ifndef DARWINN_PORT_DEFAULT_MACROS_H_ +#define DARWINN_PORT_DEFAULT_MACROS_H_ + +#include "port/default/port_from_tf/macros.h" + +#endif // DARWINN_PORT_DEFAULT_MACROS_H_ diff --git a/port/default/math_util.h b/port/default/math_util.h new file mode 100644 index 0000000..a55b8f8 --- /dev/null +++ b/port/default/math_util.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_MATH_UTIL_H_ +#define DARWINN_PORT_DEFAULT_MATH_UTIL_H_ + +#include "port/default/port_from_tf/math_util.h" + +#endif // DARWINN_PORT_DEFAULT_MATH_UTIL_H_ diff --git a/port/default/memory_barriers.h b/port/default/memory_barriers.h new file mode 100644 index 0000000..2f5a527 --- /dev/null +++ b/port/default/memory_barriers.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_MEMORY_BARRIERS_H_ +#define DARWINN_PORT_DEFAULT_MEMORY_BARRIERS_H_ + +#define barrier() asm volatile("" ::: "memory") +#define isb(option) asm volatile("" ::: "memory") +#define dsb(option) asm volatile("" ::: "memory") +#define dmb(option) asm volatile("" ::: "memory") +#define dfb(option) asm volatile("" ::: "memory") +#define rmb() asm volatile("" ::: "memory") +#define wmb() asm volatile("" ::: "memory") + +#endif // DARWINN_PORT_DEFAULT_MEMORY_BARRIERS_H_ diff --git a/port/default/mutex.h b/port/default/mutex.h new file mode 100644 index 0000000..4e4a80f --- /dev/null +++ b/port/default/mutex.h @@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_DEFAULT_MUTEX_H_ +#define DARWINN_PORT_DEFAULT_MUTEX_H_ + +#include //NOLINT + +#include "port/default/thread_annotations.h" + +namespace platforms { +namespace darwinn { + +class LOCKABLE Mutex { + public: + Mutex() = default; + ~Mutex() = default; + + void Lock() EXCLUSIVE_LOCK_FUNCTION() { mutex_.lock(); } + + void Unlock() UNLOCK_FUNCTION() { mutex_.unlock(); } + + private: + std::mutex mutex_; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_MUTEX_H_ diff --git a/port/default/port_from_tf/BUILD b/port/default/port_from_tf/BUILD new file mode 100644 index 0000000..464655e --- /dev/null +++ b/port/default/port_from_tf/BUILD @@ -0,0 +1,69 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Description: +# Port of various google3 libraries and utilities. + +package(default_visibility = ["//visibility:public"]) + +# Files are originally ported from TensorFlow. +licenses(["notice"]) + +exports_files(["logging.h"]) + +# Test helpers +cc_library( + name = "status_test_util", + testonly = 1, + hdrs = [ + "status_test_util.h", + ], + deps = [ + ":port_from_tf", + "//port/default:port", + ], +) + +# Port of various google3 libraries and utilities. +cc_library( + name = "port_from_tf", + srcs = [ + "logging.cc", + "status.cc", + "statusor.cc", + ], + hdrs = [ + "array_slice.h", + "array_slice_internal.h", + "casts.h", + "errors.h", + "logging.h", + "macros.h", + "math_util.h", + "ptr_util.h", + "status.h", + "statusor.h", + "statusor_internals.h", + ], + deps = [ + "//port/default:port_base", + ], +) + +cc_library( + name = "thread_annotations", + hdrs = [ + "thread_annotations.h", + ], +) diff --git a/port/default/port_from_tf/LICENSE b/port/default/port_from_tf/LICENSE new file mode 100644 index 0000000..146d9b7 --- /dev/null +++ b/port/default/port_from_tf/LICENSE @@ -0,0 +1,203 @@ +Copyright 2018 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2018, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/port/default/port_from_tf/array_slice.h b/port/default/port_from_tf/array_slice.h new file mode 100644 index 0000000..f8a8616 --- /dev/null +++ b/port/default/port_from_tf/array_slice.h @@ -0,0 +1,307 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An ArraySlice represents an immutable array of elements of type +// T. It has a length "length", and a base pointer "ptr", and the +// array it represents contains the elements "ptr[0] .. ptr[len-1]". +// The backing store for the array is *not* owned by the ArraySlice +// object, and clients must arrange for the backing store to remain +// live while the ArraySlice object is in use. +// +// An ArraySlice is somewhat analogous to a StringPiece, but for +// array elements of type T. +// +// Implicit conversion operations are provided from types such as +// std::vector and util::gtl::InlinedVector. Note that ArraySlice +// objects constructed from types in this way may be invalidated by +// any operations that mutate the underlying vector. +// +// One common use for ArraySlice is when passing arguments to a +// routine where you want to be able to accept a variety of array +// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array, +// etc.). The usual approach here is to have the client explicitly +// pass in a pointer and a length, as in: +// +// void MyRoutine(const int* elems, int N) { +// for (int i = 0; i < N; i++) { .. do something with elems[i] .. } +// } +// +// Unfortunately, this leads to ugly and error-prone code at the call site: +// +// std::vector my_vector; +// MyRoutine(vector_as_array(&my_vector), my_vector.size()); +// +// util::gtl::InlinedVector my_inline_vector; +// MyRoutine(my_inline_vector.array(), my_inline_vector.size()); +// +// int my_array[10]; +// MyRoutine(my_array, 10); +// +// Instead, you can use an ArraySlice as the argument to the routine: +// +// void MyRoutine(ArraySlice a) { +// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. } +// } +// +// This makes the call sites cleaner, for the most part: +// +// std::vector my_vector; +// MyRoutine(my_vector); +// +// util::gtl::InlinedVector my_inline_vector; +// MyRoutine(my_inline_vector); +// +// int my_array[10]; +// MyRoutine(my_array); +// +// int* my_array = new int[10]; +// MyRoutine(gtl::ArraySlice(my_array, 10)); +// +// MutableArraySlice represents a mutable array of elements, and, like +// ArraySlice, does not own the backing store. The implicit constructors it +// provides allow functions not to worry about whether their mutable arguments +// refer to vectors, arrays, proto2::RepeatedFields, etc.: +// +// void MyMutatingRoutine(MutableArraySlice a) { +// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. 
} +// } +// +// std::vector my_vector; +// MyMutatingRoutine(&my_vector); +// +// int my_array[10]; +// MyMutatingRoutine(my_array); +// +// int* my_array = new int[10]; +// MyMutatingRoutine(gtl::MutableArraySlice(my_array, 10)); +// +// MyProto my_proto; +// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); } +// MyMutatingRoutine(my_proto.mutable_value()); + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_H_ + +#include +#include +#include +#include + +#include "port/default/port_from_tf/array_slice_internal.h" + +namespace platforms { +namespace darwinn { +namespace gtl { + +template +class ArraySlice { + private: + typedef array_slice_internal::ArraySliceImpl Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + ArraySlice() : impl_(nullptr, 0) {} + ArraySlice(const_pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + ArraySlice(const std::vector& v) // NOLINT(runtime/explicit) + : impl_(v.data(), v.size()) {} + + template + ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + // The constructor for any class supplying 'data() const' that returns either + // const T* or a less const-qualified version of it, and 'some_integral_type + // size() const'. proto2::RepeatedField, string and (since C++11) + // std::vector and std::array are examples of this. See + // array_slice_internal.h for details. + template > + ArraySlice(const V& v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Implicitly constructs an ArraySlice from an initializer list. This makes it + // possible to pass a brace-enclosed initializer list to a function expecting + // an ArraySlice: + // void Process(ArraySlice x); + // Process({1, 2, 3}); + // The data referenced by the initializer_list must outlive this + // ArraySlice. For example, "ArraySlice s={1,2};" and "return + // ArraySlice({3,4});" are errors, as the resulting ArraySlice may + // reference data that is no longer valid. + ArraySlice(std::initializer_list v) // NOLINT(runtime/explicit) + : impl_(v.begin(), v.size()) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. 
+ ArraySlice(const ArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + const_pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + const_reference operator[](size_type i) const { return impl_[i]; } + const_reference at(size_type i) const { return impl_.at(i); } + const_reference front() const { return impl_.front(); } + const_reference back() const { return impl_.back(); } + + const_iterator begin() const { return impl_.begin(); } + const_iterator end() const { return impl_.end(); } + const_reverse_iterator rbegin() const { return impl_.rbegin(); } + const_reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + // These relational operators have the same semantics as the + // std::vector relational operators: they do deep (elementwise) + // comparisons. Array slices are equal iff their size is the same + // and all their elements are equal. + bool operator==(ArraySlice other) const { return impl_ == other.impl_; } + bool operator!=(ArraySlice other) const { return impl_ != other.impl_; } + + private: + Impl impl_; +}; + +// Mutable version of ArraySlice, which allows the clients to mutate the +// underlying data. It is implicitly convertible to ArraySlice since it provides +// the data() and size() methods with correct signatures. When a +// MutableArraySlice is created from a pointer to a container (as opposed to raw +// memory pointer), the pointer must not be null. +// +// A note on const-ness: "mutable" here refers to the mutability of the +// underlying data, not of the slice itself. It is perfectly reasonable to have +// a variable of type "const MutableArraySlice"; this means that the bounds +// of the view on the array cannot be changed, but the underlying data in the +// array still may be modified. This is akin to a "T* const" pointer, as opposed +// to a "const T*" pointer (corresponding to a non-const ArraySlice). +template +class MutableArraySlice { + private: + typedef array_slice_internal::MutableArraySliceImpl Impl; + + public: + typedef T value_type; + typedef typename Impl::pointer pointer; + typedef typename Impl::const_pointer const_pointer; + typedef typename Impl::reference reference; + typedef typename Impl::const_reference const_reference; + typedef typename Impl::iterator iterator; + typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::reverse_iterator reverse_iterator; + typedef typename Impl::const_reverse_iterator const_reverse_iterator; + typedef typename Impl::size_type size_type; + typedef typename Impl::difference_type difference_type; + + static const size_type npos = Impl::npos; + + MutableArraySlice() : impl_(nullptr, 0) {} + MutableArraySlice(pointer array, size_type length) : impl_(array, length) {} + + // Implicit conversion constructors + MutableArraySlice(std::vector* v) // NOLINT(runtime/explicit) + : impl_(v->data(), v->size()) {} + + template + MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit) + : impl_(a, N) {} + + // The constructor for any class supplying 'T* data()' or 'T* mutable_data()' + // (the former is called if both exist), and 'some_integral_type size() + // const'. 
proto2::RepeatedField is an example of this. Also supports string + // arguments, when T==char. The appropriate ctor is selected using SFINAE. See + // array_slice_internal.h for details. + template > + MutableArraySlice(V* v) // NOLINT(runtime/explicit) + : impl_(v) {} + + // Substring of another MutableArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + // If len==npos, the substring continues till the end of x. + MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len) + : impl_(x.impl_, pos, len) {} + + // Accessors. + pointer data() const { return impl_.data(); } + size_type size() const { return impl_.size(); } + size_type length() const { return size(); } + bool empty() const { return size() == 0; } + + void clear() { impl_.clear(); } + + reference operator[](size_type i) const { return impl_[i]; } + reference at(size_type i) const { return impl_.at(i); } + reference front() const { return impl_.front(); } + reference back() const { return impl_.back(); } + + iterator begin() const { return impl_.begin(); } + iterator end() const { return impl_.end(); } + reverse_iterator rbegin() const { return impl_.rbegin(); } + reverse_iterator rend() const { return impl_.rend(); } + + void remove_prefix(size_type n) { impl_.remove_prefix(n); } + void remove_suffix(size_type n) { impl_.remove_suffix(n); } + void pop_back() { remove_suffix(1); } + void pop_front() { remove_prefix(1); } + + bool operator==(ArraySlice other) const { + return ArraySlice(*this) == other; + } + bool operator!=(ArraySlice other) const { + return ArraySlice(*this) != other; + } + + // DEPRECATED(jacobsa): Please use data() instead. + pointer mutable_data() const { return impl_.data(); } + + private: + Impl impl_; +}; + +template +const typename ArraySlice::size_type ArraySlice::npos; +template +const typename MutableArraySlice::size_type MutableArraySlice::npos; + +} // namespace gtl +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_H_ diff --git a/port/default/port_from_tf/array_slice_internal.h b/port/default/port_from_tf/array_slice_internal.h new file mode 100644 index 0000000..eb456a6 --- /dev/null +++ b/port/default/port_from_tf/array_slice_internal.h @@ -0,0 +1,274 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by +// array_slice.h. + +// Helper functions and templates for ArraySlice. 
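A short, self-contained sketch of the call-site patterns described in the array_slice.h header comment above (vectors, C arrays, and initializer lists all binding to the same parameter type). It is an illustration, not part of the patch; Sum, Clear, and Demo are hypothetical names.

#include <vector>

#include "port/default/port_from_tf/array_slice.h"

using platforms::darwinn::gtl::ArraySlice;
using platforms::darwinn::gtl::MutableArraySlice;

// Sums elements without caring whether the caller holds a vector,
// a C array, or a brace-enclosed initializer list.
int Sum(ArraySlice<int> values) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}

// Zeroes the caller's storage in place through a mutable view.
void Clear(MutableArraySlice<int> values) {
  for (int& v : values) v = 0;
}

int Demo() {
  std::vector<int> v = {1, 2, 3};
  int arr[3] = {4, 5, 6};
  int total = Sum(v) + Sum(arr) + Sum({7, 8});  // 6 + 15 + 15 = 36
  Clear(&v);   // containers are passed by pointer for the mutable view
  Clear(arr);  // C arrays bind directly
  return total;
}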
+ +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_INTERNAL_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_INTERNAL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "port/default/port_from_tf/logging.h" + +namespace platforms { +namespace darwinn { +namespace gtl { +namespace array_slice_internal { + +// Template logic for generic constructors. + +// Wrappers whose Get() delegates to the appropriate method of a container, and +// is defined when this method exists. Delegates to the const method if C is a +// const type. +struct Data { + template + static decltype(std::declval().data()) Get(C* v) { + return v->data(); + } +}; + +struct MutableData { + template + static decltype(std::declval().mutable_data()) Get(C* v) { + return v->mutable_data(); + } +}; + +struct Size { + template + static decltype(std::declval().size()) Get(C* v) { + return v->size(); + } +}; + +struct MutableStringData { + // Defined only for string. + static char* Get(std::string* v) { + return v->empty() ? nullptr : &*v->begin(); + } +}; + +// Checks whether M::Get(C*) is defined and has a return type R such that +// Checker::valid()==true. +template +struct HasGetHelper : public M { + private: + struct None {}; + // M::Get is selected when it is viable. Get(...) is selected otherwise. + using M::Get; + static None Get(...); + + public: + static constexpr bool HasGet() { + using Result = decltype(Get(std::declval())); + return !std::is_same() && Checker::template valid(); + } +}; + +// Defines HasGet() for a particular method, container, and checker. If +// HasGet()==true, provides Get() that delegates to the method. +template ::HasGet()> +struct Wrapper { + static constexpr bool HasGet() { return false; } +}; + +template +struct Wrapper { + static constexpr bool HasGet() { return true; } + static decltype(M::Get(std::declval())) Get(C* v) { return M::Get(v); } +}; + +// Type checker for a method returning an integral value. +struct SizeChecker { + template + static constexpr bool valid() { + return std::is_integral::value; + } +}; + +// Type checker for a method returning either a pointer to T or a less const +// version of that. +template +struct DataChecker { + // We want to enable conversion from std::vector to ArraySlice + // but + // disable conversion from std::vector to ArraySlice. Here we + // use + // the fact that U** is convertible to Q* const* if and only if Q is the same + // type or a more cv-qualified version of U. + template + static constexpr bool valid() { + return std::is_convertible::value; + } +}; + +// Aliases to A if A::HasGet()==true, or to B otherwise. +template +using FirstWithGet = typename std::conditional::type; + +// Wraps C::data() const, returning a pointer to const data. +template +using ContainerData = Wrapper, const C>; + +// Wraps a method returning a pointer to mutable data. Prefers data() over +// mutable_data(), and handles strings when T==char. If data() returns a pointer +// to mutable data, it is most likely overloaded, but may also be a single +// method 'T* C::data() const' in a non-STL-compliant container. +template +using ContainerMutableData = + FirstWithGet, C>, + FirstWithGet, C>, + Wrapper, C>>>; + +// Wraps C::size() const. +template +using ContainerSize = Wrapper; + +// Implementation class for ArraySlice and MutableArraySlice. In the case of +// ArraySlice, T will be a const type; for MutableArraySlice, T will be a +// mutable type. 
+template +class ArraySliceImplBase { + public: + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef pointer iterator; + typedef const_pointer const_iterator; + typedef std::reverse_iterator reverse_iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + static const size_type npos = static_cast(-1); + + ArraySliceImplBase(pointer array, size_type length) + : ptr_(array), length_(length) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len) + : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {} + + // Some of the const methods below return pointers and references to mutable + // data. This is only the case in this internal class; ArraySlice and + // MutableArraySlice provide deep-constness. + + pointer data() const { return ptr_; } + size_type size() const { return length_; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + reference operator[](size_type i) const { return ptr_[i]; } + reference at(size_type i) const { + DCHECK_LT(i, length_); + return ptr_[i]; + } + reference front() const { + DCHECK_GT(length_, 0); + return ptr_[0]; + } + reference back() const { + DCHECK_GT(length_, 0); + return ptr_[length_ - 1]; + } + + void remove_prefix(size_type n) { + DCHECK_GE(length_, n); + ptr_ += n; + length_ -= n; + } + void remove_suffix(size_type n) { + DCHECK_GE(length_, n); + length_ -= n; + } + + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + reverse_iterator rbegin() const { return reverse_iterator(end()); } + reverse_iterator rend() const { return reverse_iterator(begin()); } + + bool operator==(const ArraySliceImplBase& other) const { + if (size() != other.size()) return false; + if (data() == other.data()) return true; + return std::equal(data(), data() + size(), other.data()); + } + bool operator!=(const ArraySliceImplBase& other) const { + return !(*this == other); + } + + private: + pointer ptr_; + size_type length_; +}; + +template +class ArraySliceImpl : public ArraySliceImplBase { + public: + using ArraySliceImplBase::ArraySliceImplBase; + + // Defined iff the data and size accessors for the container C have been + // defined. + template + using EnableIfConvertibleFrom = + typename std::enable_if::HasGet() && + ContainerSize::HasGet()>::type; + + // Constructs from a container when EnableIfConvertibleFrom is + // defined. std::addressof handles types with overloaded operator&. 
+ template + explicit ArraySliceImpl(const C& v) + : ArraySliceImplBase(ContainerData::Get(std::addressof(v)), + ContainerSize::Get(std::addressof(v))) {} +}; + +template +class MutableArraySliceImpl : public ArraySliceImplBase { + public: + using ArraySliceImplBase::ArraySliceImplBase; + + template + using EnableIfConvertibleFrom = + typename std::enable_if::HasGet() && + ContainerSize::HasGet()>::type; + + template + explicit MutableArraySliceImpl(C* v) + : ArraySliceImplBase(ContainerMutableData::Get(v), + ContainerSize::Get(v)) {} +}; + +} // namespace array_slice_internal +} // namespace gtl +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_ARRAY_SLICE_INTERNAL_H_ diff --git a/port/default/port_from_tf/casts.h b/port/default/port_from_tf/casts.h new file mode 100644 index 0000000..9d4dda9 --- /dev/null +++ b/port/default/port_from_tf/casts.h @@ -0,0 +1,140 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Various Google-specific casting templates. +// +// This code is compiled directly on many platforms, including client +// platforms like Windows, Mac, and embedded systems. Before making +// any changes here, make sure that you're not breaking any platforms. +// + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_CASTS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_CASTS_H_ + +#include // for memcpy + +namespace platforms { +namespace darwinn { + +// bit_cast is a template function that implements the +// equivalent of "*reinterpret_cast(&source)". We need this in +// very low-level functions like the protobuf library and fast math +// support. +// +// float f = 3.14159265358979; +// int i = bit_cast(f); +// // i = 0x40490fdb +// +// The classical address-casting method is: +// +// // WRONG +// float f = 3.14159265358979; // WRONG +// int i = * reinterpret_cast(&f); // WRONG +// +// The address-casting method actually produces undefined behavior +// according to ISO C++ specification section 3.10 -15 -. Roughly, this +// section says: if an object in memory has one type, and a program +// accesses it with a different type, then the result is undefined +// behavior for most values of "different type". +// +// This is true for any cast syntax, either *(int*)&f or +// *reinterpret_cast(&f). And it is particularly true for +// conversions between integral lvalues and floating-point lvalues. +// +// The purpose of 3.10 -15- is to allow optimizing compilers to assume +// that expressions with different types refer to different memory. gcc +// 4.0.1 has an optimizer that takes advantage of this. So a +// non-conforming program quietly produces wildly incorrect output. +// +// The problem is not the use of reinterpret_cast. The problem is type +// punning: holding an object in memory of one type and reading its bits +// back using a different type. 
+// +// The C++ standard is more subtle and complex than this, but that +// is the basic idea. +// +// Anyways ... +// +// bit_cast<> calls memcpy() which is blessed by the standard, +// especially by the example in section 3.9 . Also, of course, +// bit_cast<> wraps up the nasty logic in one place. +// +// Fortunately memcpy() is very fast. In optimized mode, with a +// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline +// code with the minimal amount of data movement. On a 32-bit system, +// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8) +// compiles to two loads and two stores. +// +// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1. +// +// WARNING: if Dest or Source is a non-POD type, the result of the memcpy +// is likely to surprise you. +// +// Props to Bill Gibbons for the compile time assertion technique and +// Art Komninos and Igor Tandetnik for the msvc experiments. +// +// -- mec 2005-10-17 + +template +inline Dest bit_cast(const Source& source) { + static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match"); + + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +// Identity metafunction. +// DEPRECATED(b/10657301): No std equivalent, but it's trivial. +// NOTE: This will be released as part of ABCL because there are some sensible +// usage. However, this will be moved along with other template stuff from GTL +// into a separate file, which resides in a directory for MPL. +// Also, this will be moved into another namespace to avoid confusion with +// other similar things like libstdc++ __gnu_cxx::identity. +template +struct identity_ { + typedef T type; +}; + +// Use implicit_cast as a safe version of static_cast or const_cast +// for implicit conversions. For example: +// - Upcasting in a type hierarchy. +// - Performing arithmetic conversions (int32 to int64, int to double, etc.). +// - Adding const or volatile qualifiers. +// +// In general, implicit_cast can be used to convert this code +// To to = from; +// DoSomething(to); +// to this +// DoSomething(implicit_cast(from)); +// +// base::identity_ is used to make a non-deduced context, which +// forces all callers to explicitly specify the template argument. +template +inline To implicit_cast(typename identity_::type to) { + return to; +} + +// This version of implicit_cast is used when two template arguments +// are specified. It's obsolete and should not be used. +template +inline To implicit_cast(typename identity_::type const& f) { + return f; +} + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_CASTS_H_ diff --git a/port/default/port_from_tf/errors.h b/port/default/port_from_tf/errors.h new file mode 100644 index 0000000..9a538c2 --- /dev/null +++ b/port/default/port_from_tf/errors.h @@ -0,0 +1,81 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
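A minimal sketch of the two cast helpers defined in casts.h above: bit_cast for well-defined type punning and implicit_cast for conversions the language would already perform implicitly. This is an illustration, not part of the patch; CastDemo and the Base/Derived types are hypothetical.

#include <cstdint>

#include "port/default/casts.h"

using platforms::darwinn::bit_cast;
using platforms::darwinn::implicit_cast;

struct Base {};
struct Derived : Base {};

void CastDemo() {
  // Reinterpret the bits of a float as a 32-bit integer without the
  // undefined behavior of *reinterpret_cast<int*>(&f).
  float f = 3.14159265358979f;
  uint32_t bits = bit_cast<uint32_t>(f);  // 0x40490fdb
  (void)bits;

  // implicit_cast only allows conversions that would happen implicitly,
  // e.g. Derived* -> Base* or int -> int64_t; the reverse directions
  // do not compile.
  Derived d;
  Base* b = implicit_cast<Base*>(&d);
  int64_t wide = implicit_cast<int64_t>(42);
  (void)b;
  (void)wide;
}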
+==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_ERRORS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_ERRORS_H_ + +#include "port/default/error_codes.h" +#include "port/default/port_from_tf/logging.h" +#include "port/default/port_from_tf/macros.h" +#include "port/default/port_from_tf/status.h" +#include "port/default/strcat.h" + +namespace platforms { +namespace darwinn { +namespace util { + +typedef error::Code Code; + +// Append some context to an error message. Each time we append +// context put it on a new line, since it is possible for there +// to be several layers of additional context. +template +void AppendToMessage(Status* status, Args... args) { + *status = + Status(status->code(), StrCat(status->error_message(), "\n\t", args...)); +} + +// Convenience functions for generating and using error status. +// Example usage: +// status.Update(error::InvalidArgumentError("The ", foo, " isn't right.")); +// if (error::IsInvalidArgument(status)) { ... } +// switch (status.code()) { case error::INVALID_ARGUMENT: ... } + +#define DECLARE_ERROR(FUNC, CONST) \ + template \ + Status FUNC##Error(Args... args) { \ + return Status(::platforms::darwinn::util::error::CONST, StrCat(args...)); \ + } \ + inline bool Is##FUNC(const Status& status) { \ + return status.code() == ::platforms::darwinn::util::error::CONST; \ + } + +DECLARE_ERROR(Cancelled, CANCELLED) +DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT) +DECLARE_ERROR(NotFound, NOT_FOUND) +DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS) +DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED) +DECLARE_ERROR(Unavailable, UNAVAILABLE) +DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION) +DECLARE_ERROR(OutOfRange, OUT_OF_RANGE) +DECLARE_ERROR(Unimplemented, UNIMPLEMENTED) +DECLARE_ERROR(Internal, INTERNAL) +DECLARE_ERROR(Aborted, ABORTED) +DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED) +DECLARE_ERROR(DataLoss, DATA_LOSS) +DECLARE_ERROR(Unknown, UNKNOWN) +DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED) +DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED) + +#undef DECLARE_ERROR + +// The CanonicalCode() for non-errors. +using ::platforms::darwinn::util::error::OK; + +} // namespace util +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_ERRORS_H_ diff --git a/port/default/port_from_tf/logging.cc b/port/default/port_from_tf/logging.cc new file mode 100644 index 0000000..cedfaa3 --- /dev/null +++ b/port/default/port_from_tf/logging.cc @@ -0,0 +1,180 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "port/default/port_from_tf/logging.h" + +// When building as part of Android System, equivalent logging macros are +// already available. 
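A sketch of how the helpers generated by DECLARE_ERROR above might be used together with AppendToMessage. It is an illustration, not part of the patch: SetBatchSize and Caller are hypothetical, a default-constructed util::Status is assumed to be OK (as in the TensorFlow code this is ported from), and StrCat is assumed to accept integral arguments as the TensorFlow original does.

#include "port/default/errors.h"

namespace util = platforms::darwinn::util;

// Hypothetical validation routine built on the generated helpers.
util::Status SetBatchSize(int batch_size) {
  if (batch_size <= 0) {
    // InvalidArgumentError(...) StrCat's its arguments into one message.
    return util::InvalidArgumentError("batch_size must be positive, got ",
                                      batch_size);
  }
  return util::Status();  // Assumed to be OK when default-constructed.
}

void Caller() {
  util::Status status = SetBatchSize(-1);
  if (util::IsInvalidArgument(status)) {
    // Each appended fragment lands on a new line of the status message.
    util::AppendToMessage(&status, "while configuring the inference request");
  }
}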
+#if !defined(DARWINN_PORT_ANDROID_SYSTEM) + +#include "port/default/port_from_tf/macros.h" + +#if defined(ANDROID) || defined(__ANDROID__) +#include +#include +#include +#endif + +#include +#include + +namespace platforms { +namespace darwinn { +namespace internal { + +LogMessage::LogMessage(const char* fname, int line, int severity) + : fname_(fname), line_(line), severity_(severity) {} + +#if defined(ANDROID) || defined(__ANDROID__) + +void LogMessage::GenerateLogMessage() { + int android_log_level; + switch (severity_) { + case INFO: + android_log_level = ANDROID_LOG_INFO; + break; + case WARNING: + android_log_level = ANDROID_LOG_WARN; + break; + case ERROR: + android_log_level = ANDROID_LOG_ERROR; + break; + case FATAL: + android_log_level = ANDROID_LOG_FATAL; + break; + default: + if (severity_ < INFO) { + android_log_level = ANDROID_LOG_VERBOSE; + } else { + android_log_level = ANDROID_LOG_ERROR; + } + break; + } + + std::stringstream ss; + const char* const partial_name = strrchr(fname_, '/'); + ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_ + << " " << str(); + __android_log_write(android_log_level, "native", ss.str().c_str()); + + // Also log to stderr (for standalone Android apps). + std::cerr << "native : " << ss.str() << std::endl; + + // Android logging at level FATAL does not terminate execution, so abort() + // is still required to stop the program. + if (severity_ == FATAL) { + std::abort(); + } +} + +#else + +void LogMessage::GenerateLogMessage() { + // TODO: For open source version, replace this with something + // that logs through the env or something and fill in appropriate time info. + fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_, + str().c_str()); +} + +#endif + +LogMessage::~LogMessage() { GenerateLogMessage(); } + +LogMessageFatal::LogMessageFatal(const char* file, int line) + : LogMessage(file, line, FATAL) {} +LogMessageFatal::~LogMessageFatal() { + // abort() ensures we don't return (we promised we would not via + // ATTRIBUTE_NORETURN). + GenerateLogMessage(); + std::abort(); +} + +void LogString(const char* fname, int line, int severity, + const std::string& message) { + LogMessage(fname, line, severity) << message; +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "char value " << (short)v; + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "signed char value " << (short)v; + } +} + +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v) { + if (v >= 32 && v <= 126) { + (*os) << "'" << v << "'"; + } else { + (*os) << "unsigned char value " << (unsigned short)v; + } +} + +#if LANG_CXX11 +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p) { + (*os) << "nullptr"; +} +#endif + +CheckOpMessageBuilder::CheckOpMessageBuilder(const char* exprtext) + : stream_(new std::ostringstream) { + *stream_ << "Check failed: " << exprtext << " ("; +} + +CheckOpMessageBuilder::~CheckOpMessageBuilder() { delete stream_; } + +std::ostream* CheckOpMessageBuilder::ForVar2() { + *stream_ << " vs. 
"; + return stream_; +} + +std::string* CheckOpMessageBuilder::NewString() { + *stream_ << ")"; + return new std::string(stream_->str()); +} + +} // namespace internal +} // namespace darwinn +} // namespace platforms + +#endif // !DARWINN_PORT_ANDROID_SYSTEM + +namespace platforms { +namespace darwinn { +namespace internal { +namespace { +// TODO: can we make the logging level somehow local to a request? +int log_level = 0; +} // namespace + +void SetLoggingLevel(int level) { log_level = level; } + +int GetLoggingLevel() { return log_level; } + +} // namespace internal +} // namespace darwinn +} // namespace platforms diff --git a/port/default/port_from_tf/logging.h b/port/default/port_from_tf/logging.h new file mode 100644 index 0000000..d654be7 --- /dev/null +++ b/port/default/port_from_tf/logging.h @@ -0,0 +1,333 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_LOGGING_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_LOGGING_H_ + +#if defined(DARWINN_PORT_ANDROID_SYSTEM) + +#ifndef LOG_TAG +#define LOG_TAG "Darwinn" +#endif + +// When building as part of Android System, equivalent logging macros are +// already available. +#include "android-base/logging.h" + +// C++ 14 (current Android runtime libc++) does not have overloads for +// std::nullptr_t. Add this as a work around until Android runtime libc++ +// supports C++ 17. +static inline std::ostream& operator<<(std::ostream& s, std::nullptr_t) { + return s << "nullptr"; +} + +#else // !DARWINN_PORT_ANDROID_SYSTEM + +#include +#include +#include +#include +#include +#include + +#include "port/default/port_from_tf/macros.h" + +namespace platforms { +namespace darwinn { + +const int INFO = 0; // base_logging::INFO; +const int WARNING = 1; // base_logging::WARNING; +const int ERROR = 2; // base_logging::ERROR; +const int FATAL = 3; // base_logging::FATAL; +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES; + +namespace internal { + +class LogMessage : public std::basic_ostringstream { + public: + LogMessage(const char* fname, int line, int severity); + ~LogMessage(); + + protected: + void GenerateLogMessage(); + + private: + const char* fname_; + int line_; + int severity_; +}; + +// LogMessageFatal ensures the process will exit in failure after +// logging this message. 
+class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) ATTRIBUTE_COLD; + ATTRIBUTE_NORETURN ~LogMessageFatal(); +}; + +#define _LOG_INFO \ + ::platforms::darwinn::internal::LogMessage(__FILE__, __LINE__, \ + platforms::darwinn::INFO) +#define _LOG_WARNING \ + ::platforms::darwinn::internal::LogMessage(__FILE__, __LINE__, \ + platforms::darwinn::WARNING) +#define _LOG_ERROR \ + ::platforms::darwinn::internal::LogMessage(__FILE__, __LINE__, \ + platforms::darwinn::ERROR) +#define _LOG_FATAL \ + ::platforms::darwinn::internal::LogMessageFatal(__FILE__, __LINE__) + +#define LOG(severity) _LOG_##severity + +#define _LOG_QFATAL _LOG_FATAL + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by NDEBUG, so the check will be executed regardless of +// compilation mode. Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + if (PREDICT_FALSE(!(condition))) LOG(FATAL) << "Check failed: " #condition " " + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. +template +inline const T& GetReferenceableValue(const T& t) { + return t; +} +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline short GetReferenceableValue(short t) { return t; } +inline unsigned short GetReferenceableValue(unsigned short t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline long GetReferenceableValue(long t) { return t; } +inline unsigned long GetReferenceableValue(unsigned long t) { return t; } +inline long long GetReferenceableValue(long long t) { return t; } +inline unsigned long long GetReferenceableValue(unsigned long long t) { + return t; +} + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream* os, const T& v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> +void MakeCheckOpValueString(std::ostream* os, const char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const signed char& v); +template <> +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); + +#if LANG_CXX11 +// We need an explicit specialization for std::nullptr_t. +template <> +void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p); +#endif + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is non-NULL. +struct CheckOpString { + CheckOpString(std::string* str) : str_(str) {} + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + operator bool() const { return PREDICT_FALSE(str_ != NULL); } + std::string* str_; +}; + +// Build the error message string. Specify no inlining for code size. +template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, + const char* exprtext) ATTRIBUTE_NOINLINE; + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. 
Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class CheckOpMessageBuilder { + public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char* exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream* ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream* ForVar2(); + // Get the result (inserts the closing ")"). + std::string* NewString(); + + private: + std::ostringstream* stream_; +}; + +template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, + const char* exprtext) { + CheckOpMessageBuilder comb(exprtext); + MakeCheckOpValueString(comb.ForVar1(), v1); + MakeCheckOpValueString(comb.ForVar2(), v2); + return comb.NewString(); +} + +// Helper functions for CHECK_OP macro. +// The (int, int) specialization works around the issue that the compiler +// will not instantiate the template version of the function on values of +// unnamed enum type - see comment below. +// The (size_t, int) and (int, size_t) specialization are to handle unsigned +// comparison errors while still being thorough with the comparison. +#define DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline std::string* name##Impl(const T1& v1, const T2& v2, \ + const char* exprtext) { \ + if (PREDICT_TRUE(v1 op v2)) \ + return NULL; \ + else \ + return ::platforms::darwinn::internal::MakeCheckOpString(v1, v2, \ + exprtext); \ + } \ + inline std::string* name##Impl(int v1, int v2, const char* exprtext) { \ + return name##Impl(v1, v2, exprtext); \ + } \ + inline std::string* name##Impl(const size_t v1, const int v2, \ + const char* exprtext) { \ + if (PREDICT_FALSE(v2 < 0)) { \ + return ::platforms::darwinn::internal::MakeCheckOpString(v1, v2, \ + exprtext); \ + } \ + const size_t uval = (size_t)((unsigned)v1); \ + return name##Impl(uval, v2, exprtext); \ + } \ + inline std::string* name##Impl(const int v1, const size_t v2, \ + const char* exprtext) { \ + if (PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ + return ::platforms::darwinn::internal::MakeCheckOpString(v1, v2, \ + exprtext); \ + } \ + const size_t uval = (size_t)((unsigned)v2); \ + return name##Impl(v1, uval, exprtext); \ + } + +// We use the full name Check_EQ, Check_NE, etc. in case the file including +// base/logging.h provides its own #defines for the simpler names EQ, NE, etc. +// This happens if, for example, those are used as token names in a +// yacc grammar. +DEFINE_CHECK_OP_IMPL(Check_EQ, ==) // Compilation error with CHECK_EQ(NULL, x)? +DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +DEFINE_CHECK_OP_IMPL(Check_LE, <=) +DEFINE_CHECK_OP_IMPL(Check_LT, <) +DEFINE_CHECK_OP_IMPL(Check_GE, >=) +DEFINE_CHECK_OP_IMPL(Check_GT, >) +#undef DEFINE_CHECK_OP_IMPL + +// In optimized mode, use CheckOpString to hint to compiler that +// the while condition is unlikely. 
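A sketch of the text a failing CHECK_* produces, calling MakeCheckOpString directly so the format assembled by CheckOpMessageBuilder above is visible; bytes_read/expected are illustrative names.

#include <iostream>
#include <memory>
#include <string>

#include "port/default/port_from_tf/logging.h"

int main() {
  namespace internal = platforms::darwinn::internal;
  // This is the string a failed CHECK_EQ(bytes_read, expected) would log.
  std::unique_ptr<std::string> message(
      internal::MakeCheckOpString(3, 7, "bytes_read == expected"));
  std::cout << *message << std::endl;
  // Prints: Check failed: bytes_read == expected (3 vs. 7)
  return 0;
}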
+#define CHECK_OP_LOG(name, op, val1, val2) \ + while (::platforms::darwinn::internal::CheckOpString _result = \ + ::platforms::darwinn::internal::name##Impl( \ + ::platforms::darwinn::internal::GetReferenceableValue(val1), \ + ::platforms::darwinn::internal::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + ::platforms::darwinn::internal::LogMessageFatal(__FILE__, __LINE__) \ + << *(_result.str_) + +#define CHECK_OP(name, op, val1, val2) CHECK_OP_LOG(name, op, val1, val2) + +// CHECK_EQ/NE/... +#define CHECK_EQ(val1, val2) CHECK_OP(Check_EQ, ==, val1, val2) +#define CHECK_NE(val1, val2) CHECK_OP(Check_NE, !=, val1, val2) +#define CHECK_LE(val1, val2) CHECK_OP(Check_LE, <=, val1, val2) +#define CHECK_LT(val1, val2) CHECK_OP(Check_LT, <, val1, val2) +#define CHECK_GE(val1, val2) CHECK_OP(Check_GE, >=, val1, val2) +#define CHECK_GT(val1, val2) CHECK_OP(Check_GT, >, val1, val2) + +#ifndef NDEBUG +// DCHECK_EQ/NE/... +#define DCHECK(condition) CHECK(condition) +#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) + +#else + +#define DCHECK(condition) \ + while (false && (condition)) LOG(FATAL) + +// NDEBUG is defined, so DCHECK_EQ(x, y) and so on do nothing. +// However, we still want the compiler to parse x and y, because +// we don't want to lose potentially useful errors and warnings. +// _DCHECK_NOP is a helper, and should not be used outside of this file. +#define _DCHECK_NOP(x, y) \ + while (false && ((void)(x), (void)(y), 0)) LOG(FATAL) + +#define DCHECK_EQ(x, y) _DCHECK_NOP(x, y) +#define DCHECK_NE(x, y) _DCHECK_NOP(x, y) +#define DCHECK_LE(x, y) _DCHECK_NOP(x, y) +#define DCHECK_LT(x, y) _DCHECK_NOP(x, y) +#define DCHECK_GE(x, y) _DCHECK_NOP(x, y) +#define DCHECK_GT(x, y) _DCHECK_NOP(x, y) + +#endif + +// These are for when you don't want a CHECK failure to print a verbose +// stack trace. The implementation of CHECK* in this file already doesn't. +#define QCHECK(condition) CHECK(condition) +#define QCHECK_EQ(x, y) CHECK_EQ(x, y) +#define QCHECK_NE(x, y) CHECK_NE(x, y) +#define QCHECK_LE(x, y) CHECK_LE(x, y) +#define QCHECK_LT(x, y) CHECK_LT(x, y) +#define QCHECK_GE(x, y) CHECK_GE(x, y) +#define QCHECK_GT(x, y) CHECK_GT(x, y) + +} // namespace internal +} // namespace darwinn +} // namespace platforms + +#endif // !DARWINN_PORT_ANDROID_SYSTEM + +// Following macros are shared between default and Android. +namespace platforms { +namespace darwinn { +namespace internal { + +// There is no system VLOG in Android runtime. Provide one here. +#undef VLOG_IS_ON +extern int GetLoggingLevel(); +extern void SetLoggingLevel(int new_logging_level); +#define VLOG_IS_ON(lvl) \ + (lvl <= ::platforms::darwinn::internal::GetLoggingLevel()) + +#ifndef VLOG +#define VLOG(lvl) \ + if (VLOG_IS_ON(lvl)) LOG(INFO) +#endif + +} // namespace internal +} // namespace darwinn +} // namespace platforms + + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_LOGGING_H_ diff --git a/port/default/port_from_tf/macros.h b/port/default/port_from_tf/macros.h new file mode 100644 index 0000000..11bd19c --- /dev/null +++ b/port/default/port_from_tf/macros.h @@ -0,0 +1,140 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_MACROS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_MACROS_H_ + +// When building as part of Android System, many of these macros are +// already available through Android. +#if defined(DARWINN_PORT_ANDROID_SYSTEM) + +#include "android-base/macros.h" + +// Compiler attributes +#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) +// Compiler supports GCC-style attributes +#define ATTRIBUTE_NORETURN __attribute__((noreturn)) +#define ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define ATTRIBUTE_COLD __attribute__((cold)) +#define ATTRIBUTE_WEAK __attribute__((weak)) +#define ATTRIBUTE_PACKED __attribute__((packed)) +#define ABSL_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define PRINTF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +#define SCANF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__scanf__, string_index, first_to_check))) +#else +#error "Not intended to be compiled by compilers other than clang/gcc" +#endif + +// GCC can be told that a certain branch is not likely to be taken (for +// instance, a CHECK failure), and use that information in static analysis. +// Giving it this information can help it optimize for the common case in +// the absence of better information (ie. -fprofile-arcs). 
+#if defined(COMPILER_GCC3) +#define PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define PREDICT_FALSE(x) (x) +#define PREDICT_TRUE(x) (x) +#endif + +#else // !DARWINN_PORT_ANDROID_SYSTEM + +// Compiler attributes +#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) +// Compiler supports GCC-style attributes +#define ATTRIBUTE_NORETURN __attribute__((noreturn)) +#define ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define ATTRIBUTE_UNUSED __attribute__((unused)) +#define ATTRIBUTE_COLD __attribute__((cold)) +#define ATTRIBUTE_WEAK __attribute__((weak)) +#define ATTRIBUTE_PACKED __attribute__((packed)) +#define ABSL_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define PRINTF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +#define SCANF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__scanf__, string_index, first_to_check))) +#elif defined(COMPILER_MSVC) +// Non-GCC equivalents +#define ATTRIBUTE_NORETURN __declspec(noreturn) +#define ATTRIBUTE_NOINLINE +#define ATTRIBUTE_UNUSED +#define ATTRIBUTE_COLD +#define ABSL_MUST_USE_RESULT +#define ATTRIBUTE_PACKED +#define PRINTF_ATTRIBUTE(string_index, first_to_check) +#define SCANF_ATTRIBUTE(string_index, first_to_check) +#else +// Non-GCC equivalents +#define ATTRIBUTE_NORETURN +#define ATTRIBUTE_NOINLINE +#define ATTRIBUTE_UNUSED +#define ATTRIBUTE_COLD +#define ATTRIBUTE_WEAK +#define ABSL_MUST_USE_RESULT +#define ATTRIBUTE_PACKED +#define PRINTF_ATTRIBUTE(string_index, first_to_check) +#define SCANF_ATTRIBUTE(string_index, first_to_check) +#endif + +// GCC can be told that a certain branch is not likely to be taken (for +// instance, a CHECK failure), and use that information in static analysis. +// Giving it this information can help it optimize for the common case in +// the absence of better information (ie. -fprofile-arcs). +#if defined(COMPILER_GCC3) +#define PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define PREDICT_FALSE(x) (x) +#define PREDICT_TRUE(x) (x) +#endif + +// A macro to disallow the copy constructor and operator= functions +// This is usually placed in the private: declarations for a class. +#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +// The ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// +// The expression ARRAYSIZE(a) is a compile-time constant of type +// size_t. +#define ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900) +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. 
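A short sketch of the utility macros above in use; DmaBuffer and Sum are hypothetical names chosen only for illustration.

#include <cstddef>

#include "port/default/port_from_tf/macros.h"

class DmaBuffer {
 public:
  DmaBuffer() = default;

 private:
  // Copying an object that owns device memory would be a bug; forbid it.
  DISALLOW_COPY_AND_ASSIGN(DmaBuffer);
};

int Sum(const int* values, size_t count) {
  if (PREDICT_FALSE(values == nullptr)) return 0;  // hint: unlikely branch
  int total = 0;
  for (size_t i = 0; i < count; ++i) total += values[i];
  return total;
}

int main() {
  static const int kTable[] = {1, 2, 3, 5, 8};
  // ARRAYSIZE(kTable) is a compile-time constant (5 here).
  return Sum(kTable, ARRAYSIZE(kTable)) == 19 ? 0 : 1;
}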
+#define LANG_CXX11 1 +#endif + +#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) +#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") +#define FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT +#endif +#endif + +#ifndef FALLTHROUGH_INTENDED +#define FALLTHROUGH_INTENDED \ + do { \ + } while (0) +#endif + +#endif // !DARWINN_PORT_ANDROID_SYSTEM +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_MACROS_H_ diff --git a/port/default/port_from_tf/math_util.h b/port/default/port_from_tf/math_util.h new file mode 100644 index 0000000..e31d876 --- /dev/null +++ b/port/default/port_from_tf/math_util.h @@ -0,0 +1,122 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_MATH_UTIL_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_MATH_UTIL_H_ + +#include +#include + +#include "port/default/port_from_tf/logging.h" + +namespace platforms { +namespace darwinn { + +class MathUtil { + public: + // ---------------------------------------------------------------------- + // CeilOfRatio + // FloorOfRatio + // Returns the ceil (resp. floor) of the ratio of two integers. + // + // * IntegralType: any integral type, whether signed or not. + // * numerator: any integer: positive, negative, or zero. + // * denominator: a non-zero integer, positive or negative. + // + // This implementation is correct, meaning there is never any precision loss, + // and there is never an overflow. However, if the type is signed, having + // numerator == MathLimits::kMin and denominator == -1 is not a + // valid input, because kMin has a greater absolute value than kMax. + // + // Input validity is DCHECKed. When not in debug mode, invalid inputs raise + // SIGFPE. + // + // This method has been designed and tested so that it should always be + // preferred to alternatives. Indeed, there exist popular recipes to compute + // the result, such as casting to double, but they are in general incorrect. + // In cases where an alternative technique is correct, performance measurement + // showed the provided implementation is faster. + template + static IntegralType CeilOfRatio(IntegralType numerator, + IntegralType denominator) { + return CeilOrFloorOfRatio(numerator, denominator); + } + template + static IntegralType FloorOfRatio(IntegralType numerator, + IntegralType denominator) { + return CeilOrFloorOfRatio(numerator, denominator); + } + + template + static IntegralType CeilOrFloorOfRatio(IntegralType numerator, + IntegralType denominator); +}; + +// ---- CeilOrFloorOfRatio ---- +// This is a branching-free, cast-to-double-free implementation. +// +// Casting to double is in general incorrect because of loss of precision +// when casting an int64 into a double. +// +// There's a bunch of 'recipes' to compute a integer ceil (or floor) on the web, +// and most of them are incorrect. 
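A small arithmetic sketch of the CeilOfRatio/FloorOfRatio helpers declared above, e.g. for sizing a buffer in whole pages; the 4 KiB page size is just an example value.

#include <cstdint>

#include "port/default/port_from_tf/math_util.h"

using platforms::darwinn::MathUtil;

int main() {
  constexpr int64_t kPageBytes = 4096;
  // How many whole 4 KiB pages are needed for 10000 bytes of activations?
  int64_t pages = MathUtil::CeilOfRatio<int64_t>(10000, kPageBytes);   // 3
  int64_t full  = MathUtil::FloorOfRatio<int64_t>(10000, kPageBytes);  // 2
  // Negative numerators round toward the correct side as well:
  // CeilOfRatio<int64_t>(-10000, 4096) == -2, FloorOfRatio == -3.
  return (pages == 3 && full == 2) ? 0 : 1;
}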
+template +IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator, + IntegralType denominator) { + static_assert(std::is_integral::value, + "CeilOrFloorOfRatio is only defined for integral types."); + DCHECK_NE(0, denominator) << "Division by zero is not supported."; + DCHECK(!std::is_signed::value || + numerator != std::numeric_limits::min() || + denominator != -1) + << "Dividing " << numerator << "by -1 is not supported: it would SIGFPE."; + + const IntegralType rounded_toward_zero = numerator / denominator; + const IntegralType intermediate_product = rounded_toward_zero * denominator; + + if (ceil) { // Compile-time condition: not an actual branching + // When rounded_toward_zero is negative, then an adjustment is never needed: + // the real ratio is negative, and so rounded toward zero is the ceil. + // When rounded_toward_zero is non-negative, an adjustment is needed if the + // sign of the difference numerator - intermediate_product is the same as + // the sign of the denominator. + // + // + // Using a bool and then a static_cast to IntegralType is not strictly + // necessary, but it makes the code clear, and anyway the compiler should + // get rid of it. + const bool needs_adjustment = + (rounded_toward_zero >= 0) && + ((denominator > 0 && numerator > intermediate_product) || + (denominator < 0 && numerator < intermediate_product)); + const IntegralType adjustment = static_cast(needs_adjustment); + const IntegralType ceil_of_ratio = rounded_toward_zero + adjustment; + return ceil_of_ratio; + } else { + // Floor case: symmetrical to the previous one + const bool needs_adjustment = + (rounded_toward_zero <= 0) && + ((denominator > 0 && numerator < intermediate_product) || + (denominator < 0 && numerator > intermediate_product)); + const IntegralType adjustment = static_cast(needs_adjustment); + const IntegralType floor_of_ratio = rounded_toward_zero - adjustment; + return floor_of_ratio; + } +} + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_MATH_UTIL_H_ diff --git a/port/default/port_from_tf/ptr_util.h b/port/default/port_from_tf/ptr_util.h new file mode 100644 index 0000000..ff6a35a --- /dev/null +++ b/port/default/port_from_tf/ptr_util.h @@ -0,0 +1,94 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_PTR_UTIL_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_PTR_UTIL_H_ + +#include + +namespace platforms { +namespace darwinn { +namespace gtl { + +// Trait to select overloads and return types for MakeUnique. +template +struct MakeUniqueResult { + using scalar = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using array = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using invalid = void; +}; + +// MakeUnique(...) is an early implementation of C++14 std::make_unique. 
+// It is designed to be 100% compatible with std::make_unique so that the +// eventual switchover will be a simple renaming operation. +template +typename MakeUniqueResult::scalar MakeUnique(Args&&... args) { // NOLINT + return std::unique_ptr( + new T(std::forward(args)...)); // NOLINT(build/c++11) +} + +// Overload for array of unknown bound. +// The allocation of arrays needs to use the array form of new, +// and cannot take element constructor arguments. +template +typename MakeUniqueResult::array MakeUnique(size_t n) { + return std::unique_ptr(new typename std::remove_extent::type[n]()); +} + +// Reject arrays of known bound. +template +typename MakeUniqueResult::invalid MakeUnique(Args&&... /* args */) = + delete; // NOLINT + +// ----------------------------------------------------------------------------- +// Function Template: WrapUnique() +// ----------------------------------------------------------------------------- +// +// Transfers ownership of a raw pointer to a `std::unique_ptr`. The returned +// value is a `std::unique_ptr` of deduced type. +// +// Example: +// X* NewX(int, int); +// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. +// +// `absl::WrapUnique` is useful for capturing the output of a raw pointer +// factory. However, prefer 'absl::MakeUnique(args...) over +// 'absl::WrapUnique(new T(args...))'. +// +// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. +// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. +// +// Note that `absl::WrapUnique(p)` is valid only if `delete p` is a valid +// expression. In particular, `absl::WrapUnique()` cannot wrap pointers to +// arrays, functions or void, and it must not be used to capture pointers +// obtained from array-new expressions (even though that would compile!). +template +std::unique_ptr WrapUnique(T* ptr) { + static_assert(!std::is_array::value, "array types are unsupported"); + static_assert(std::is_object::value, "non-object types are unsupported"); + return std::unique_ptr(ptr); +} + +} // namespace gtl +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_PTR_UTIL_H_ diff --git a/port/default/port_from_tf/status.cc b/port/default/port_from_tf/status.cc new file mode 100644 index 0000000..afb33bc --- /dev/null +++ b/port/default/port_from_tf/status.cc @@ -0,0 +1,141 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "port/default/port_from_tf/status.h" + +#include +#include + +#include "port/default/error_codes.h" + +namespace platforms { +namespace darwinn { +namespace util { + +Status::Status(error::Code code, const std::string& msg) { + assert(code != error::OK); + state_ = std::unique_ptr(new State); + state_->code = code; + state_->msg = msg; +} + +void Status::Update(const Status& new_status) { + if (ok()) { + *this = new_status; + } +} + +void Status::SlowCopyFrom(const State* src) { + if (src == nullptr) { + state_ = nullptr; + } else { + state_ = std::unique_ptr(new State(*src)); + } +} + +const std::string& Status::empty_string() { + static std::string* empty = new std::string; + return *empty; +} + +std::string Status::ToString() const { + if (state_ == nullptr) { + return "OK"; + } else { + char tmp[30]; + const char* type; + switch (code()) { + case error::CANCELLED: + type = "Cancelled"; + break; + case error::UNKNOWN: + type = "Unknown"; + break; + case error::INVALID_ARGUMENT: + type = "Invalid argument"; + break; + case error::DEADLINE_EXCEEDED: + type = "Deadline exceeded"; + break; + case error::NOT_FOUND: + type = "Not found"; + break; + case error::ALREADY_EXISTS: + type = "Already exists"; + break; + case error::PERMISSION_DENIED: + type = "Permission denied"; + break; + case error::UNAUTHENTICATED: + type = "Unauthenticated"; + break; + case error::RESOURCE_EXHAUSTED: + type = "Resource exhausted"; + break; + case error::FAILED_PRECONDITION: + type = "Failed precondition"; + break; + case error::ABORTED: + type = "Aborted"; + break; + case error::OUT_OF_RANGE: + type = "Out of range"; + break; + case error::UNIMPLEMENTED: + type = "Unimplemented"; + break; + case error::INTERNAL: + type = "Internal"; + break; + case error::UNAVAILABLE: + type = "Unavailable"; + break; + case error::DATA_LOSS: + type = "Data loss"; + break; + default: + snprintf(tmp, sizeof(tmp), "Unknown code(%d)", + static_cast(code())); + type = tmp; + break; + } + std::string result(type); + result += ": "; + result += state_->msg; + return result; + } +} + +void Status::IgnoreError() const { + // no-op +} + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +std::string* CheckOpHelperOutOfLine(const Status& v, const char* msg) { + std::string r("Non-OK-status: "); + r += msg; + r += " status: "; + r += v.ToString(); + // Leaks string but this is only to be used in a fatal error message + return new std::string(r); +} + +} // namespace util +} // namespace darwinn +} // namespace platforms diff --git a/port/default/port_from_tf/status.h b/port/default/port_from_tf/status.h new file mode 100644 index 0000000..fd58d0e --- /dev/null +++ b/port/default/port_from_tf/status.h @@ -0,0 +1,147 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUS_H_ + +#include +#include +#include +#include +#include + +#include "port/default/error_codes.h" +#include "port/default/port_from_tf/logging.h" + +namespace platforms { +namespace darwinn { +namespace util { + +/// Denotes success or failure of a call. +class Status { + public: + /// Create a success status. + Status() {} + + /// \brief Create a status with the specified error code and msg as a + /// human-readable string containing more detailed information. + Status(error::Code code, const std::string& msg); + + /// Copy the specified status. + Status(const Status& s); + void operator=(const Status& s); + + static Status OK() { return Status(); } + + /// Returns true iff the status indicates success. + bool ok() const { return (state_ == NULL); } + + error::Code code() const { return ok() ? error::OK : state_->code; } + error::Code CanonicalCode() const { return code(); } + + const std::string& error_message() const { + return ok() ? empty_string() : state_->msg; + } + const std::string& message() const { + return ok() ? empty_string() : state_->msg; + } + + bool operator==(const Status& x) const; + bool operator!=(const Status& x) const; + + /// \brief If `ok()`, stores `new_status` into `*this`. If `!ok()`, + /// preserves the current status, but may augment with additional + /// information about `new_status`. + /// + /// Convenient way of keeping track of the first error encountered. + /// Instead of: + /// `if (overall_status.ok()) overall_status = new_status` + /// Use: + /// `overall_status.Update(new_status);` + void Update(const Status& new_status); + + /// \brief Return a string representation of this status suitable for + /// printing. Returns the string `"OK"` for success. + std::string ToString() const; + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; + + private: + static const std::string& empty_string(); + struct State { + error::Code code; + std::string msg; + }; + // OK status has a `NULL` state_. Otherwise, `state_` points to + // a `State` structure containing the error code and message(s) + std::unique_ptr state_; + + void SlowCopyFrom(const State* src); +}; + +inline Status OkStatus() { return Status(); } + +inline Status::Status(const Status& s) + : state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {} + +inline void Status::operator=(const Status& s) { + // The following condition catches both aliasing (when this == &s), + // and the common case where both s and *this are ok. 
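A usage sketch for the Status class above, showing the first-error-wins Update() pattern and the ToString() output; the message text is illustrative.

#include <iostream>

#include "port/default/port_from_tf/status.h"

namespace util = platforms::darwinn::util;

int main() {
  util::Status overall;  // default-constructed: OK
  overall.Update(util::Status(util::error::UNAVAILABLE, "device busy"));
  overall.Update(util::Status(util::error::INTERNAL, "later failure"));

  // Only the first error is kept; ToString() prefixes the canonical code.
  std::cout << overall.ToString() << std::endl;
  // Prints: Unavailable: device busy
  return overall.ok() ? 0 : 1;
}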
+ if (state_ != s.state_) { + SlowCopyFrom(s.state_.get()); + } +} + +inline bool Status::operator==(const Status& x) const { + return (this->state_ == x.state_) || (ToString() == x.ToString()); +} + +inline bool Status::operator!=(const Status& x) const { return !(*this == x); } + +/// @ingroup core +std::ostream& operator<<(std::ostream& os, const Status& x); + +typedef std::function StatusCallback; + +extern std::string* CheckOpHelperOutOfLine(const Status& v, const char* msg); + +inline std::string* CheckOpHelper(Status v, const char* msg) { + if (v.ok()) return nullptr; + return CheckOpHelperOutOfLine(v, msg); +} + +#define DO_CHECK_OK(val, level) \ + while (auto _result = CheckOpHelper(val, #val)) LOG(level) << *(_result) + +#define CHECK_OK(val) DO_CHECK_OK(val, FATAL) +#define QCHECK_OK(val) DO_CHECK_OK(val, QFATAL) + +// DEBUG only version of CHECK_OK. Compiler still parses 'val' even in opt +// mode. +#ifndef NDEBUG +#define DCHECK_OK(val) CHECK_OK(val) +#else +#define DCHECK_OK(val) \ + while (false && (Status::OK() == (val))) LOG(FATAL) +#endif + +} // namespace util +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUS_H_ diff --git a/port/default/port_from_tf/statusor.cc b/port/default/port_from_tf/statusor.cc new file mode 100644 index 0000000..4d979db --- /dev/null +++ b/port/default/port_from_tf/statusor.cc @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "port/default/port_from_tf/statusor.h" + +#include "port/default/error_codes.h" +#include "port/default/port_from_tf/errors.h" +#include "port/default/port_from_tf/logging.h" +#include "port/default/port_from_tf/macros.h" +#include "port/default/port_from_tf/status.h" +#include "port/default/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace util { +namespace internal_statusor { + +void Helper::HandleInvalidStatusCtorArg(Status* status) { + const char* kMessage = + "An OK status is not a valid constructor argument to StatusOr"; + LOG(ERROR) << kMessage; + // Fall back to tensorflow::error::INTERNAL. + *status = InternalError(kMessage); +} + +void Helper::Crash(const Status& status) { + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; + unreachable(); // NOLINT +} + +} // namespace internal_statusor +} // namespace util +} // namespace darwinn +} // namespace platforms diff --git a/port/default/port_from_tf/statusor.h b/port/default/port_from_tf/statusor.h new file mode 100644 index 0000000..39be0db --- /dev/null +++ b/port/default/port_from_tf/statusor.h @@ -0,0 +1,311 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// StatusOr is the union of a Status object and a T +// object. StatusOr models the concept of an object that is either a +// usable value, or an error Status explaining why such a value is +// not present. To this end, StatusOr does not allow its Status +// value to be Status::OK. Furthermore, the value of a StatusOr +// must not be null. This is enforced by a debug check in most cases, +// but even when it is not, clients must not set the value to null. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_H_ + +#include "port/default/error_codes.h" +#include "port/default/port_from_tf/macros.h" +#include "port/default/port_from_tf/status.h" +#include "port/default/port_from_tf/statusor_internals.h" + +namespace platforms { +namespace darwinn { +namespace util { + +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +template +class ABSL_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructible/assignable if T is copy + // constructible. 
+ StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructible/assignable if T is move + // constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + // TODO: These should not participate in overload resolution if U + // is not convertible to T. + template + StatusOr(const StatusOr& other); + template + StatusOr(StatusOr&& other); + + // Conversion copy/move assignment operator, T must be convertible from U. + template + StatusOr& operator=(const StatusOr& other); + template + StatusOr& operator=(StatusOr&& other); + + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing error::INTERNAL as a fallback. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); + + // TODO: Add operator=(T) overloads. + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); + + // Returns this->status().ok() + bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK(). + const Status& status() const &; + Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. + // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const &; + T& ValueOrDie() &; + const T&& ValueOrDie() const &&; + T&& ValueOrDie() &&; + + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. 
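A compiled counterpart to the commented examples above; LoadCalibration is a hypothetical factory, shown only to illustrate returning either a value or a Status and moving the payload out once ok() is known.

#include <utility>
#include <vector>

#include "port/default/port_from_tf/errors.h"
#include "port/default/port_from_tf/statusor.h"

namespace util = platforms::darwinn::util;

util::StatusOr<std::vector<int>> LoadCalibration(int channels) {
  if (channels <= 0) {
    return util::InvalidArgumentError("channels must be positive");
  }
  return std::vector<int>(channels, 0);  // implicit StatusOr from a T
}

int main() {
  auto result = LoadCalibration(8);
  if (!result.ok()) {
    LOG(ERROR) << result.status();
    return 1;
  }
  // Move the payload out once the StatusOr is known to hold a value.
  std::vector<int> calibration = std::move(result).ValueOrDie();
  return calibration.size() == 8 ? 0 : 1;
}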
+ void IgnoreError() const; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() : Base(Status(error::UNKNOWN, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const Status& status) : Base(status) {} + +template +StatusOr& StatusOr::operator=(const Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} + +template +StatusOr& StatusOr::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +template +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + +template +template +inline StatusOr& StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; +} + +template +template +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + +template +template +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } + return *this; +} + +template +const Status& StatusOr::status() const & { + return this->status_; +} +template +Status StatusOr::status() && { + return ok() ? Status::OK() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const & { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +void StatusOr::IgnoreError() const { + // no-op +} + +} // namespace util +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_H_ diff --git a/port/default/port_from_tf/statusor_internals.h b/port/default/port_from_tf/statusor_internals.h new file mode 100644 index 0000000..eb69ea1 --- /dev/null +++ b/port/default/port_from_tf/statusor_internals.h @@ -0,0 +1,250 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_INTERNALS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_INTERNALS_H_ + +#include "port/default/port_from_tf/macros.h" +#include "port/default/port_from_tf/status.h" +#include "port/default/unreachable.h" + +namespace platforms { +namespace darwinn { +namespace util { +namespace internal_statusor { + +class Helper { + public: + // Move type-agnostic error handling to the .cc. 
+ static void HandleInvalidStatusCtorArg(Status*); + ATTRIBUTE_NORETURN static void Crash(const Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) unreachable(); +#endif + new (p) T(std::forward(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. +template +class StatusOrData { + template + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template + StatusOrData(StatusOrData&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(Status&& status) : status_(std::move(status)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + void Assign(const T& value) { + if (ok()) { + data_.~T(); + MakeValue(value); + } else { + MakeValue(value); + status_ = Status::OK(); + } + } + + void Assign(T&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::move(value)); + } else { + MakeValue(std::move(value)); + status_ = Status::OK(); + } + } + + void Assign(const Status& status) { + Clear(); + status_ = status; + EnsureNotOk(); + } + + void Assign(Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // We make it a union to be able to initialize exactly how we need without + // waste. + // Eg. in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. 
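The union-plus-placement-new layout used by StatusOrData above, reduced to a standalone sketch: the payload stays unconstructed on the error path and is created later with placement new, which is what the dummy_/data_ union enables. ManualSlot is illustrative only, not part of the runtime.

#include <new>
#include <string>
#include <utility>

template <typename T>
class ManualSlot {
 public:
  ManualSlot() : empty_{} {}
  ~ManualSlot() { Clear(); }

  template <typename... Args>
  void Emplace(Args&&... args) {
    Clear();
    new (&value_) T(std::forward<Args>(args)...);  // placement new into union
    engaged_ = true;
  }

  void Clear() {
    if (engaged_) value_.~T();  // manual destruction, mirroring ~StatusOrData
    engaged_ = false;
  }

  bool engaged() const { return engaged_; }
  const T& value() const { return value_; }

 private:
  struct Empty {};
  union {
    Empty empty_;  // active while no T has been constructed
    T value_;
  };
  bool engaged_ = false;
};

int main() {
  ManualSlot<std::string> slot;
  slot.Emplace("edge tpu");
  return slot.engaged() && slot.value() == "edge tpu" ? 0 : 1;
}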
+ Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (!ok()) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed + // argument. + template + void MakeValue(Arg&& arg) { + internal_statusor::PlacementNew(&dummy_, std::forward(arg)); + } + + // Construct the status (ie. status_) through placement new with the passed + // argument. + template + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew(&status_, + std::forward(args)...); + } +}; + +// Helper base class to allow implicitly deleted constructors and assignment +// operations in StatusOr. +// TraitsBase will explicitly delete what it can't support and StatusOr will +// inherit that behavior implicitly. +template +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = default; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = default; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = delete; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = delete; +}; + +} // namespace internal_statusor +} // namespace util +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_STATUSOR_INTERNALS_H_ diff --git a/port/default/port_from_tf/thread_annotations.h b/port/default/port_from_tf/thread_annotations.h new file mode 100644 index 0000000..819fba5 --- /dev/null +++ b/port/default/port_from_tf/thread_annotations.h @@ -0,0 +1,176 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file contains the macro definitions for thread safety +// annotations that allow the developers to document the locking policies +// of their multi-threaded code. The annotations can also help program +// analysis tools to identify potential thread safety issues. +// +// The primary documentation on these annotations is external: +// http://clang.llvm.org/docs/ThreadSafetyAnalysis.html +// +// The annotations are implemented using compiler attributes. +// Using the macros defined here instead of the raw attributes allows +// for portability and future compatibility. +// +// When referring to mutexes in the arguments of the attributes, you should +// use variable names or more complex expressions (e.g. my_object->mutex_) +// that evaluate to a concrete mutex object whenever possible. 
If the mutex +// you want to refer to is not in scope, you may use a member pointer +// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object. +// + +#ifndef DARWINN_PORT_DEFAULT_PORT_FROM_TF_THREAD_ANNOTATIONS_H_ +#define DARWINN_PORT_DEFAULT_PORT_FROM_TF_THREAD_ANNOTATIONS_H_ + +#if defined(__clang__) && (!defined(SWIG)) +#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x)) +#else +#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op +#endif + +// Document if a shared variable/field needs to be protected by a mutex. +// GUARDED_BY allows the user to specify a particular mutex that should be +// held when accessing the annotated variable. GUARDED_VAR indicates that +// a shared variable is guarded by some unspecified mutex, for use in rare +// cases where a valid mutex expression cannot be specified. +#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x)) +#define GUARDED_VAR // no-op + +// Document if the memory location pointed to by a pointer should be guarded +// by a mutex when dereferencing the pointer. PT_GUARDED_VAR is analagous to +// GUARDED_VAR. Note that a pointer variable to a shared memory location +// could itself be a shared variable. For example, if a shared global pointer +// q, which is guarded by mu1, points to a shared memory location that is +// guarded by mu2, q should be annotated as follows: +// int *q GUARDED_BY(mu1) PT_GUARDED_BY(mu2); +#define PT_GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded_by(x)) +#define PT_GUARDED_VAR // no-op + +// Document the acquisition order between locks that can be held +// simultaneously by a thread. For any two locks that need to be annotated +// to establish an acquisition order, only one of them needs the annotation. +// (i.e. You don't have to annotate both locks with both ACQUIRED_AFTER +// and ACQUIRED_BEFORE.) +#define ACQUIRED_AFTER(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_after(__VA_ARGS__)) + +#define ACQUIRED_BEFORE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__)) + +#define ACQUIRE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquire_capability(__VA_ARGS__)) + +#define ACQUIRE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__)) + +#define RELEASE(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(release_capability(__VA_ARGS__)) + +// Document a function that expects a mutex to be held prior to entry. +// The mutex is expected to be held both on entry to and exit from the +// function. +#define EXCLUSIVE_LOCKS_REQUIRED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_locks_required(__VA_ARGS__)) + +#define SHARED_LOCKS_REQUIRED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(shared_locks_required(__VA_ARGS__)) + +// Document the locks acquired in the body of the function. These locks +// cannot be held when calling this function (for instance, when the +// mutex implementation is non-reentrant). +#define LOCKS_EXCLUDED(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(locks_excluded(__VA_ARGS__)) + +// Document a function that returns a mutex without acquiring it. For example, +// a public getter method that returns a pointer to a private mutex should +// be annotated with LOCK_RETURNED. +#define LOCK_RETURNED(x) THREAD_ANNOTATION_ATTRIBUTE__(lock_returned(x)) + +// Document if a class/type is a lockable type (such as the Mutex class). +#define LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(lockable) + +// Document if a class does RAII locking (such as the MutexLock class). 
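A brief sketch of the annotations above applied to a class; std::mutex is used only to keep the example self-contained (Clang's analysis is most effective with a mutex type annotated as LOCKABLE), and RequestQueue is a hypothetical name.

#include <mutex>
#include <vector>

#include "port/default/port_from_tf/thread_annotations.h"

class RequestQueue {
 public:
  void Push(int request_id) LOCKS_EXCLUDED(mu_) {
    std::lock_guard<std::mutex> lock(mu_);
    PushLocked(request_id);
  }

 private:
  void PushLocked(int request_id) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    pending_.push_back(request_id);
  }

  std::mutex mu_;
  std::vector<int> pending_ GUARDED_BY(mu_);
};

int main() {
  RequestQueue queue;
  queue.Push(42);
  return 0;
}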
+// The constructor should use LOCK_FUNCTION to specify the mutex that is +// acquired, and the destructor should use UNLOCK_FUNCTION with no arguments; +// the analysis will assume that the destructor unlocks whatever the +// constructor locked. +#define SCOPED_LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(scoped_lockable) + +// Document functions that acquire a lock in the body of a function, and do +// not release it. +#define EXCLUSIVE_LOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_lock_function(__VA_ARGS__)) + +#define SHARED_LOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(shared_lock_function(__VA_ARGS__)) + +// Document functions that expect a lock to be held on entry to the function, +// and release it in the body of the function. +#define UNLOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(unlock_function(__VA_ARGS__)) + +// Document functions that try to acquire a lock, and return success or failure +// (or a non-boolean value that can be interpreted as a boolean). +// The first argument should be true for functions that return true on success, +// or false for functions that return false on success. The second argument +// specifies the mutex that is locked on success. If unspecified, it is assumed +// to be 'this'. +#define EXCLUSIVE_TRYLOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(exclusive_trylock_function(__VA_ARGS__)) + +#define SHARED_TRYLOCK_FUNCTION(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(shared_trylock_function(__VA_ARGS__)) + +// Document functions that dynamically check to see if a lock is held, and fail +// if it is not held. +#define ASSERT_EXCLUSIVE_LOCK(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_exclusive_lock(__VA_ARGS__)) + +#define ASSERT_SHARED_LOCK(...) \ + THREAD_ANNOTATION_ATTRIBUTE__(assert_shared_lock(__VA_ARGS__)) + +// Turns off thread safety checking within the body of a particular function. +// This is used as an escape hatch for cases where either (a) the function +// is correct, but the locking is more complicated than the analyzer can handle, +// or (b) the function contains race conditions that are known to be benign. +#define NO_THREAD_SAFETY_ANALYSIS \ + THREAD_ANNOTATION_ATTRIBUTE__(no_thread_safety_analysis) + +// TS_UNCHECKED should be placed around lock expressions that are not valid +// C++ syntax, but which are present for documentation purposes. These +// annotations will be ignored by the analysis. +#define TS_UNCHECKED(x) "" + +namespace platforms { +namespace darwinn { +namespace thread_safety_analysis { + +// Takes a reference to a guarded data member, and returns an unguarded +// reference. +template +inline const T& ts_unchecked_read(const T& v) NO_THREAD_SAFETY_ANALYSIS { + return v; +} + +template +inline T& ts_unchecked_read(T& v) NO_THREAD_SAFETY_ANALYSIS { + return v; +} +} // namespace thread_safety_analysis +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_PORT_FROM_TF_THREAD_ANNOTATIONS_H_ diff --git a/port/default/ptr_util.h b/port/default/ptr_util.h new file mode 100644 index 0000000..3f67518 --- /dev/null +++ b/port/default/ptr_util.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_PTR_UTIL_H_ +#define DARWINN_PORT_DEFAULT_PTR_UTIL_H_ + +#include "port/default/port_from_tf/ptr_util.h" + +#endif // DARWINN_PORT_DEFAULT_PTR_UTIL_H_ diff --git a/port/default/semaphore.h b/port/default/semaphore.h new file mode 100644 index 0000000..9c23574 --- /dev/null +++ b/port/default/semaphore.h @@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_SEMAPHORE_H_ +#define DARWINN_PORT_DEFAULT_SEMAPHORE_H_ + +#include + +#include //NOLINT +#include //NOLINT +#include //NOLINT + +namespace platforms { +namespace darwinn { + +namespace internal { + +// Base class, not for direct use. +class Semaphore { + public: + virtual ~Semaphore() = default; + + bool Take() { + std::unique_lock lock(mutex_); + while (count_ == 0) { + cv_.wait(lock); + } + count_--; + + return true; + } + + bool Take(uint32_t timeout) { + // Generally we're running one tick per ms. + auto timeout_time = + std::chrono::system_clock::now() + std::chrono::milliseconds(timeout); + + std::unique_lock lock(mutex_); + while (count_ == 0) { + if (cv_.wait_until(lock, timeout_time) == std::cv_status::timeout) { + // Timed-out, check the predicate one last time. + if (count_ != 0) { + break; + } + + return false; + } + } + count_--; + + return true; + } + + bool Give() { + mutex_.lock(); + if (count_ < max_count_) { + count_++; + } + mutex_.unlock(); + cv_.notify_one(); + + return true; + } + + protected: + Semaphore(unsigned int max_count, unsigned int initial_count) + : count_(initial_count), max_count_(max_count) {} + + private: + std::mutex mutex_; + std::condition_variable cv_; + int count_; + const int max_count_; +}; + +} // namespace internal + +// Public classes + +class BinarySemaphore : public internal::Semaphore { + public: + explicit BinarySemaphore(bool set = false) + : internal::Semaphore(1, set ? 1 : 0) {} +}; + +class CountingSemaphore : public internal::Semaphore { + public: + CountingSemaphore(unsigned int max_count, unsigned int initial_count) + : internal::Semaphore(max_count, initial_count) {} +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_SEMAPHORE_H_ diff --git a/port/default/status.h b/port/default/status.h new file mode 100644 index 0000000..02f11a3 --- /dev/null +++ b/port/default/status.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_STATUS_H_ +#define DARWINN_PORT_DEFAULT_STATUS_H_ + +#include "port/default/port_from_tf/status.h" + +#endif // DARWINN_PORT_STATUS_H_ diff --git a/port/default/status_macros.cc b/port/default/status_macros.cc new file mode 100644 index 0000000..8ff18c3 --- /dev/null +++ b/port/default/status_macros.cc @@ -0,0 +1,144 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/default/status_macros.h" + +#include "port/default/logging.h" +#include "port/default/strcat.h" + +namespace platforms { +namespace darwinn { +namespace util { +namespace status_macros { + +// TODO: Implement +static std::string CurrentStackTrace() { return ""; } + +static Status MakeStatus(error::Code code, const std::string& message) { + return Status(code, message); +} + +// Log the error at the given severity, optionally with a stack trace. +// If log_severity is NUM_SEVERITIES, nothing is logged. +static void LogError(const Status& status, const char* filename, int line, + int log_severity, bool should_log_stack_trace) { + if (PREDICT_TRUE(log_severity != NUM_SEVERITIES)) { + darwinn::internal::LogMessage log_message(filename, line, log_severity); + log_message << status; + if (should_log_stack_trace) { + log_message << "\n" << CurrentStackTrace(); + } + // Logging actually happens in LogMessage destructor. + } +} + +// Make a Status with a code, error message and payload, +// and also send it to LOG() using the given filename +// and line (unless should_log is false, or log_severity is +// NUM_SEVERITIES). If should_log_stack_trace is true, the stack +// trace is included in the log message (ignored if should_log is +// false). +static Status MakeError(const char* filename, int line, error::Code code, + const std::string& message, bool should_log, + int log_severity, bool should_log_stack_trace) { + if (PREDICT_FALSE(code == error::OK)) { + DCHECK(false) << "Cannot create error with status OK"; + code = error::UNKNOWN; + } + const Status status = MakeStatus(code, message); + if (PREDICT_TRUE(should_log)) { + LogError(status, filename, line, log_severity, should_log_stack_trace); + } + return status; +} + +// This method is written out-of-line rather than in the header to avoid +// generating a lot of inline code for error cases in all callers. 
+void MakeErrorStream::CheckNotDone() const { impl_->CheckNotDone(); } + +MakeErrorStream::Impl::Impl(const char* file, int line, error::Code code, + MakeErrorStream* error_stream, + bool is_logged_by_default) + : file_(file), + line_(line), + code_(code), + is_done_(false), + should_log_(is_logged_by_default), + log_severity_(ERROR), + should_log_stack_trace_(false), + make_error_stream_with_output_wrapper_(error_stream) {} + +MakeErrorStream::Impl::Impl(const Status& status, + PriorMessageHandling prior_message_handling, + const char* file, int line, + MakeErrorStream* error_stream) + : file_(file), + line_(line), + // Make sure we show some error, even if the call is incorrect. + code_(!status.ok() ? status.code() : error::UNKNOWN), + prior_message_handling_(prior_message_handling), + prior_message_(status.error_message()), + is_done_(false), + // Error code type is not visible here, so we can't call + // IsLoggedByDefault. + should_log_(true), + log_severity_(ERROR), + should_log_stack_trace_(false), + make_error_stream_with_output_wrapper_(error_stream) { + DCHECK(!status.ok()) << "Attempted to append/prepend error text to status OK"; +} + +MakeErrorStream::Impl::~Impl() { + // Note: error messages refer to the public MakeErrorStream class. + + DCHECK(is_done_) << "MakeErrorStream destructed without getting Status: " + << file_ << ":" << line_ << " " << stream_.str(); +} + +Status MakeErrorStream::Impl::GetStatus() { + // Note: error messages refer to the public MakeErrorStream class. + + // Getting a Status object out more than once is not harmful, but + // it doesn't match the expected pattern, where the stream is constructed + // as a temporary, loaded with a message, and then casted to Status. + DCHECK(!is_done_) << "MakeErrorStream got Status more than once: " << file_ + << ":" << line_ << " " << stream_.str(); + + is_done_ = true; + + const std::string& stream_str = stream_.str(); + const std::string str = prior_message_handling_ == kAppendToPriorMessage + ? StrCat(prior_message_, stream_str) + : StrCat(stream_str, prior_message_); + if (PREDICT_FALSE(str.empty())) { + return MakeError( + file_, line_, code_, + StrCat(str, "Error without message at ", file_, ":", line_), + true /* should_log */, ERROR /* log_severity */, + should_log_stack_trace_); + } else { + return MakeError(file_, line_, code_, str, should_log_, log_severity_, + should_log_stack_trace_); + } +} + +void MakeErrorStream::Impl::CheckNotDone() const { + DCHECK(!is_done_) << "MakeErrorStream shift called after getting Status: " + << file_ << ":" << line_ << " " << stream_.str(); +} + +} // namespace status_macros +} // namespace util +} // namespace darwinn +} // namespace platforms diff --git a/port/default/status_macros.h b/port/default/status_macros.h new file mode 100644 index 0000000..0ede878 --- /dev/null +++ b/port/default/status_macros.h @@ -0,0 +1,243 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_DEFAULT_STATUS_MACROS_H_ +#define DARWINN_PORT_DEFAULT_STATUS_MACROS_H_ + +#include +#include // NOLINT +#include +#include +#include + +#include "port/default/error_codes.h" +#include "port/default/logging.h" +#include "port/default/macros.h" +#include "port/default/status.h" +#include "port/default/statusor.h" + +namespace platforms { +namespace darwinn { +namespace util { +namespace status_macros { + +// Stream object used to collect error messages in MAKE_ERROR macros +// or append error messages with APPEND_ERROR. It accepts any +// arguments with operator<< to build an error string, and then has an +// implicit cast operator to Status, which converts the +// logged string to a Status object and returns it, after logging the +// error. At least one call to operator<< is required; a compile time +// error will be generated if none are given. Errors will only be +// logged by default for certain status codes, as defined in +// IsLoggedByDefault. This class will give DFATAL errors if you don't +// retrieve a Status exactly once before destruction. +// +// The class converts into an intermediate wrapper object +// MakeErrorStreamWithOutput to check that the error stream gets at least one +// item of input. +class MakeErrorStream { + public: + // Wrapper around MakeErrorStream that only allows for output. This + // is created as output of the first operator<< call on + // MakeErrorStream. The bare MakeErrorStream does not have a + // Status operator. The net effect of that is that you + // have to call operator<< at least once or else you'll get a + // compile time error. + class MakeErrorStreamWithOutput { + public: + explicit MakeErrorStreamWithOutput(MakeErrorStream* error_stream) + : wrapped_error_stream_(error_stream) {} + + template + MakeErrorStreamWithOutput& operator<<(const T& value) { + *wrapped_error_stream_ << value; + return *this; + } + + // Implicit cast operators to Status and StatusOr. + // Exactly one of these must be called exactly once before destruction. + operator Status() { return wrapped_error_stream_->GetStatus(); } + template + operator StatusOr() { + return wrapped_error_stream_->GetStatus(); + } + + private: + MakeErrorStream* wrapped_error_stream_; + + DISALLOW_COPY_AND_ASSIGN(MakeErrorStreamWithOutput); + }; + + // When starting from an existing error status, this determines whether we'll + // append or prepend to that status's error message. + enum PriorMessageHandling { kAppendToPriorMessage, kPrependToPriorMessage }; + + // Make an error with the given code. + template + MakeErrorStream(const char* file, int line, ERROR_CODE_TYPE code) + : impl_(new Impl(file, line, code, this, true)) {} + + template + MakeErrorStreamWithOutput& operator<<(const T& value) { + CheckNotDone(); + impl_->stream_ << value; + return impl_->make_error_stream_with_output_wrapper_; + } + + // When this message is logged (see with_logging()), include the stack trace. + MakeErrorStream& with_log_stack_trace() { + impl_->should_log_stack_trace_ = true; + return *this; + } + + // Adds RET_CHECK failure text to error message. 
+ MakeErrorStreamWithOutput& add_ret_check_failure(const char* condition) { + return *this << "RET_CHECK failure (" << impl_->file_ << ":" << impl_->line_ + << ") " << condition << " "; + } + + private: + class Impl { + public: + Impl(const char* file, int line, error::Code code, + MakeErrorStream* error_stream, bool is_logged_by_default = true); + Impl(const Status& status, PriorMessageHandling prior_message_handling, + const char* file, int line, MakeErrorStream* error_stream); + + ~Impl(); + + // This must be called exactly once before destruction. + Status GetStatus(); + + void CheckNotDone() const; + + private: + const char* file_; + int line_; + error::Code code_; + + PriorMessageHandling prior_message_handling_ = kAppendToPriorMessage; + std::string prior_message_; + bool is_done_; // true after Status object has been returned + std::ostringstream stream_; + bool should_log_; + int log_severity_; + bool should_log_stack_trace_; + + // Wrapper around the MakeErrorStream object that has a + // Status conversion. The first << operator called on + // MakeErrorStream will return this object, and only this object + // can implicitly convert to Status. The net effect of + // this is that you'll get a compile time error if you call + // MAKE_ERROR etc. without adding any output. + MakeErrorStreamWithOutput make_error_stream_with_output_wrapper_; + + friend class MakeErrorStream; + DISALLOW_COPY_AND_ASSIGN(Impl); + }; + + void CheckNotDone() const; + + // Returns the status. Used by MakeErrorStreamWithOutput. + Status GetStatus() const { return impl_->GetStatus(); } + + // Store the actual data on the heap to reduce stack frame sizes. + std::unique_ptr impl_; + + DISALLOW_COPY_AND_ASSIGN(MakeErrorStream); +}; + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. +class StatusAdaptorForMacros { + public: + explicit StatusAdaptorForMacros(Status status) : status_(std::move(status)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return PREDICT_TRUE(status_.ok()); } + + Status&& Consume() { return std::move(status_); } + + private: + Status status_; +}; + +} // namespace status_macros +} // namespace util +} // namespace darwinn +} // namespace platforms + +#define RET_CHECK(condition) \ + while (PREDICT_FALSE(!(condition))) \ + return ::platforms::darwinn::util::status_macros::MakeErrorStream( \ + __FILE__, __LINE__, ::platforms::darwinn::util::error::INTERNAL) \ + .with_log_stack_trace() \ + .add_ret_check_failure(#condition) + +#define ASSIGN_OR_ASSERT_OK(lhs, rexpr) \ + ASSIGN_OR_ASSERT_OK_IMPL( \ + STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr); + +#define ASSIGN_OR_ASSERT_OK_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ + lhs = statusor.ConsumeValueOrDie() + +#define STATUS_MACROS_CONCAT_NAME(x, y) STATUS_MACROS_CONCAT_IMPL(x, y) +#define STATUS_MACROS_CONCAT_IMPL(x, y) x##y + +#if defined(_WIN32) +#define ASSIGN_OR_RETURN(_1, _2, ...) ASSIGN_OR_RETURN_IMPL_2(_1, _2) +#else +#define ASSIGN_OR_RETURN(...) \ + STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, ASSIGN_OR_RETURN_IMPL_3, \ + ASSIGN_OR_RETURN_IMPL_2) \ + (__VA_ARGS__) + +#define STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) 
NAME +#endif + +#define ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) + +#define ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \ + ASSIGN_OR_RETURN_IMPL( \ + STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) + +#define ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (PREDICT_FALSE(!statusor.ok())) { \ + return statusor.status(); \ + } \ + lhs = std::move(statusor.ValueOrDie()) + +// For propagating errors when calling a function. +#define RETURN_IF_ERROR(expr) \ + do { \ + const ::platforms::darwinn::util::Status _status = (expr); \ + if (PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + +#define RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ + do { \ + ::platforms::darwinn::util::Status _status = (expr); \ + if (PREDICT_FALSE(!_status.ok())) { \ + ::platforms::darwinn::util::error::AppendToMessage(&_status, \ + __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#endif // DARWINN_PORT_DEFAULT_STATUS_MACROS_H_ diff --git a/port/default/statusor.h b/port/default/statusor.h new file mode 100644 index 0000000..d9f7d97 --- /dev/null +++ b/port/default/statusor.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_STATUSOR_H_ +#define DARWINN_PORT_DEFAULT_STATUSOR_H_ + +#include "port/default/port_from_tf/statusor.h" + +#endif // DARWINN_PORT_STATUSOR_H_ diff --git a/port/default/strcat.h b/port/default/strcat.h new file mode 100644 index 0000000..c39a819 --- /dev/null +++ b/port/default/strcat.h @@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_STRCAT_H_ +#define DARWINN_PORT_STRCAT_H_ + +#include +#include + +namespace platforms { +namespace darwinn { + +template +std::string StrCat(Args const&... args) { + std::ostringstream stream; + int temp[]{0, ((void)(stream << args), 0)...}; + (void)temp; + return stream.str(); +} + +template +void StrAppend(std::string* dest, Args const&... 
args) { + dest->append(StrCat(args...)); // NOLINT +} + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_STRCAT_H_ diff --git a/port/default/stringprintf.cc b/port/default/stringprintf.cc new file mode 100644 index 0000000..0f6daa5 --- /dev/null +++ b/port/default/stringprintf.cc @@ -0,0 +1,97 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2002 and onwards Google Inc. +// Author: Sanjay Ghemawat + +#include "port/default/stringprintf.h" + +#include +#include // For va_list and related operations +#include // MSVC requires this for _vsnprintf +#include + +namespace platforms { +namespace darwinn { + +#ifdef COMPILER_MSVC +enum { IS_COMPILER_MSVC = 1 }; +#else +enum { IS_COMPILER_MSVC = 0 }; +#endif + +// Lower-level routine that takes a va_list and appends to a specified +// string. All other routines are just convenience wrappers around it. +void StringAppendV(std::string* dst, const char* format, va_list ap) { + // First try with a small fixed size buffer + static const int kSpaceLength = 1024; + char space[kSpaceLength]; + + // It's possible for methods that use a va_list to invalidate the data in it + // upon use. The fix is to make a copy of the structure before using it and + // use that copy instead. + va_list backup_ap; + va_copy(backup_ap, ap); + int result = vsnprintf(space, kSpaceLength, format, backup_ap); + va_end(backup_ap); + + if (result < kSpaceLength) { + if (result >= 0) { + // Normal case -- everything fit. + dst->append(space, result); + return; + } + + if (IS_COMPILER_MSVC) { + // Error or MSVC running out of space. MSVC 8.0 and higher can be asked + // about space needed with the special idiom below: + va_copy(backup_ap, ap); + result = vsnprintf(nullptr, 0, format, backup_ap); + va_end(backup_ap); + } + + if (result < 0) { + // Just an error. + return; + } + } + + // Increase the buffer size to the size requested by vsnprintf, plus one for + // the closing \0. + int length = result + 1; + char* buf = new char[length]; + + // Restore the va_list before we use it again + va_copy(backup_ap, ap); + result = vsnprintf(buf, length, format, backup_ap); + va_end(backup_ap); + + if (result >= 0 && result < length) { + // It fit + dst->append(buf, result); + } + delete[] buf; +} + +std::string StringPrintf(const char* format, ...) { + va_list ap; + va_start(ap, format); + std::string result; + StringAppendV(&result, format, ap); + va_end(ap); + return result; +} + +} // namespace darwinn +} // namespace platforms diff --git a/port/default/stringprintf.h b/port/default/stringprintf.h new file mode 100644 index 0000000..f205eb1 --- /dev/null +++ b/port/default/stringprintf.h @@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2002 and onwards Google Inc. +// Author: Sanjay Ghemawat +// +// Printf variants that place their output in a C++ string. +// +// Usage: +// string result = StringPrintf("%d %s\n", 10, "hello"); +// +// While StringF and StreamF are recommended for use, they are difficult to +// port. Fallback to StringPrintf as an alternative as it is easier to port. + +#ifndef DARWINN_PORT_DEFAULT_STRINGPRINTF_H_ +#define DARWINN_PORT_DEFAULT_STRINGPRINTF_H_ + +#include +#include +#include + +#include "port/default/macros.h" + +namespace platforms { +namespace darwinn { + +// Returns a C++ string +std::string StringPrintf(const char* format, ...) + // Tell the compiler to do printf format string checking. + PRINTF_ATTRIBUTE(1,2); + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_DEFAULT_STRINGPRINTF_H_ diff --git a/port/default/thread_annotations.h b/port/default/thread_annotations.h new file mode 100644 index 0000000..7c3e201 --- /dev/null +++ b/port/default/thread_annotations.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_THREAD_ANNOTATIONS_H_ +#define DARWINN_PORT_DEFAULT_THREAD_ANNOTATIONS_H_ + +#include "port/default/port_from_tf/thread_annotations.h" + +#endif // DARWINN_PORT_DEFAULT_THREAD_ANNOTATIONS_H_ diff --git a/port/default/unreachable.h b/port/default/unreachable.h new file mode 100644 index 0000000..d5f4034 --- /dev/null +++ b/port/default/unreachable.h @@ -0,0 +1,24 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_DEFAULT_UNREACHABLE_H_ +#define DARWINN_PORT_DEFAULT_UNREACHABLE_H_ + +#if defined(_WIN32) +#include "port/default/unreachable_windows.h" +#else +#include "port/default/unreachable_default.h" +#endif + +#endif // DARWINN_PORT_DEFAULT_UNREACHABLE_H_ diff --git a/port/default/unreachable_default.h b/port/default/unreachable_default.h new file mode 100644 index 0000000..a5cb782 --- /dev/null +++ b/port/default/unreachable_default.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_UNREACHABLE_DEFAULT_H_ +#define DARWINN_PORT_DEFAULT_UNREACHABLE_DEFAULT_H_ + +#define unreachable() __builtin_unreachable() + +#endif // DARWINN_PORT_DEFAULT_UNREACHABLE_DEFAULT_H_ diff --git a/port/default/unreachable_windows.h b/port/default/unreachable_windows.h new file mode 100644 index 0000000..41c90a7 --- /dev/null +++ b/port/default/unreachable_windows.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_UNREACHABLE_WINDOWS_H_ +#define DARWINN_PORT_DEFAULT_UNREACHABLE_WINDOWS_H_ + +__declspec(noreturn) inline void unreachable(void) { __assume(false); } + +#endif // DARWINN_PORT_DEFAULT_UNREACHABLE_WINDOWS_H_ diff --git a/port/defs.h b/port/defs.h new file mode 100644 index 0000000..cc8208c --- /dev/null +++ b/port/defs.h @@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_DEFS_H_ +#define DARWINN_PORT_DEFS_H_ + +// Preprocessor definitions to enable/disable individual feature +// based on the selected platform: +// - DARWINN_PORT_GOOGLE3: google3 +// - DARWINN_PORT_DEFAULT: no google3 deps but third_party +// - DARWINN_PORT_ANDROID_SYSTEM: no google3 deps, use external libraries + +#if DARWINN_PORT_DEFAULT +#define DARWINN_PORT_USE_GOOGLE3 0 +#define DARWINN_PORT_USE_EXTERNAL 0 +#endif // DARWINN_PORT_DEFAULT + +#if DARWINN_PORT_FIRMWARE +#define DARWINN_PORT_USE_GOOGLE3 0 +#define DARWINN_PORT_USE_EXTERNAL 0 +#endif // DARWINN_PORT_FIRMWARE + +#if DARWINN_PORT_ANDROID_SYSTEM || DARWINN_PORT_ANDROID_EMULATOR +#define DARWINN_PORT_USE_GOOGLE3 0 +#define DARWINN_PORT_USE_EXTERNAL 1 +#endif // DARWINN_PORT_ANDROID_SYSTEM || DARWINN_PORT_ANDROID_EMULATOR + +#if DARWINN_PORT_GOOGLE3 +#define DARWINN_PORT_USE_GOOGLE3 1 +#define DARWINN_PORT_USE_EXTERNAL 0 +#endif // DARWINN_PORT_GOOGLE3 + +#endif // DARWINN_PORT_DEFS_H_ diff --git a/port/demangle.h b/port/demangle.h new file mode 100644 index 0000000..b59bee3 --- /dev/null +++ b/port/demangle.h @@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEMANGLE_H_ +#define DARWINN_PORT_DEMANGLE_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 + +#include +#include + +#include "base/demangle.h" + +// Demangle given type. On success, return true and write the +// demangled symbol name to `out`. Otherwise, return false. +// `out` is modified even if demangling is unsuccessful. +template +bool Demangle(char *out, int out_size) { + std::string mangled("_Z"); + mangled.append(typeid(T).name()); + return ::Demangle(mangled.c_str(), out, out_size); +} + +#else // !DARWINN_PORT_USE_GOOGLE3 + +// Demangle given type. On success, return true and write the +// demangled symbol name to `out`. Otherwise, return false. +// `out` is modified even if demangling is unsuccessful. +template +bool Demangle(char *out, int out_size) { + return false; +} + +#endif // !DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_DEMANGLE_H_ diff --git a/port/errors.h b/port/errors.h new file mode 100644 index 0000000..5aab328 --- /dev/null +++ b/port/errors.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_ERRORS_H_ +#define DARWINN_PORT_ERRORS_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "util/task/canonical_errors.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/errors.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_ERRORS_H_ diff --git a/port/gflags.h b/port/gflags.h new file mode 100644 index 0000000..cd36180 --- /dev/null +++ b/port/gflags.h @@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_GFLAGS_H_ +#define DARWINN_PORT_GFLAGS_H_ + +#include "port/defs.h" + +#if defined(DARWINN_PORT_ANDROID_SYSTEM) || \ + defined(DARWINN_PORT_ANDROID_EMULATOR) +// There is no gflags implementation in Android runtime. Provide a dummy +// implementation here. +#define ABSL_FLAG(type, name, val, desc) type FLAGS_##name = val +namespace absl { + template + T GetFlag(T t) { return t; } +} +inline void ParseFlags(int argc, char* argv[]) {} +#else +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" + +inline void ParseFlags(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); +} +#endif + +#endif // DARWINN_PORT_GFLAGS_H_ diff --git a/port/integral_types.h b/port/integral_types.h new file mode 100644 index 0000000..f9f9be9 --- /dev/null +++ b/port/integral_types.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_INTEGRAL_TYPES_H_ +#define DARWINN_PORT_INTEGRAL_TYPES_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "base/integral_types.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/integral_types.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_INTEGRAL_TYPES_H_ diff --git a/port/logging.h b/port/logging.h new file mode 100644 index 0000000..cc8e387 --- /dev/null +++ b/port/logging.h @@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_LOGGING_H_ +#define DARWINN_PORT_LOGGING_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_FIRMWARE +#include "port/firmware/logging.h" +#elif DARWINN_PORT_USE_GOOGLE3 +#include "base/logging.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/logging.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_LOGGING_H_ diff --git a/port/macros.h b/port/macros.h new file mode 100644 index 0000000..723c715 --- /dev/null +++ b/port/macros.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_MACROS_H_ +#define DARWINN_PORT_MACROS_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "base/macros.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/macros.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_MACROS_H_ diff --git a/port/math_util.h b/port/math_util.h new file mode 100644 index 0000000..4f095bc --- /dev/null +++ b/port/math_util.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_MATH_UTIL_H_ +#define DARWINN_PORT_MATH_UTIL_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "util/math/mathutil.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/math_util.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_MATH_UTIL_H_ diff --git a/port/mutex.h b/port/mutex.h new file mode 100644 index 0000000..79ab712 --- /dev/null +++ b/port/mutex.h @@ -0,0 +1,50 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_MUTEX_H_ +#define DARWINN_PORT_MUTEX_H_ + +#include "port/defs.h" + +// NOTE: This abstraction is brittle. Anything beyond Mutex#Lock(), +// Mutex#Unlock() is not guaranteed to work consistently across all platforms. 
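Given the note above, portable callers should restrict themselves to the plain Lock()/Unlock() pair. A minimal sketch follows; the counter is hypothetical shared state, and only the Lock()/Unlock() calls are assumed to behave the same across the firmware, google3, and default ports.

#include "port/mutex.h"

namespace {

platforms::darwinn::Mutex mu;
int counter = 0;  // hypothetical shared state protected by mu

void Increment() {
  mu.Lock();
  ++counter;
  mu.Unlock();
}

}  // namespace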
+ +#if DARWINN_PORT_FIRMWARE +#include "third_party/safertos_addons/mutex.h" + +namespace platforms { +namespace darwinn { + +using Mutex = safertos_addons::Mutex; + +} // namespace darwinn +} // namespace platforms + +#elif DARWINN_PORT_USE_GOOGLE3 +#include "absl/synchronization/mutex.h" + +namespace platforms { +namespace darwinn { + +using Mutex = absl::Mutex; + +} // namespace darwinn +} // namespace platforms + +#elif DARWINN_PORT_DEFAULT + +#include "port/default/mutex.h" + +#endif // DARWINN_PORT_DEFAULT +#endif // DARWINN_PORT_MUTEX_H_ diff --git a/port/openssl.h b/port/openssl.h new file mode 100644 index 0000000..4cca6ef --- /dev/null +++ b/port/openssl.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_OPENSSL_H_ +#define DARWINN_PORT_OPENSSL_H_ + +#include "port/defs.h" + +#endif // DARWINN_PORT_OPENSSL_H_ diff --git a/port/posix_time.cc b/port/posix_time.cc new file mode 100644 index 0000000..808cf3a --- /dev/null +++ b/port/posix_time.cc @@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/posix_time.h" + +#include +#include + +namespace platforms { +namespace darwinn { + +namespace { + +// Seconds to nanoseconds. +constexpr uint64 kSecondsToNanos = 1000ULL * 1000ULL * 1000ULL; + +} // namespace + +uint64 GetRealTimeNanos() { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + return (static_cast(ts.tv_sec) * kSecondsToNanos + + static_cast(ts.tv_nsec)); +} + +uint64 GetBootTimeNanos() { + struct timespec ts; + clock_gettime(CLOCK_BOOTTIME, &ts); + return (static_cast(ts.tv_sec) * kSecondsToNanos + + static_cast(ts.tv_nsec)); +} + +} // namespace darwinn +} // namespace platforms diff --git a/port/posix_time.h b/port/posix_time.h new file mode 100644 index 0000000..1c8d959 --- /dev/null +++ b/port/posix_time.h @@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DARWINN_PORT_POSIX_TIME_H_ +#define DARWINN_PORT_POSIX_TIME_H_ + +#include "port/integral_types.h" + +namespace platforms { +namespace darwinn { + +// Get the clock realtime in nanoseconds. +uint64 GetRealTimeNanos(); + +// Get the clock boottime in nanoseconds. +uint64 GetBootTimeNanos(); + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_POSIX_TIME_H_ diff --git a/port/protobuf_helper.h b/port/protobuf_helper.h new file mode 100644 index 0000000..bab33ce --- /dev/null +++ b/port/protobuf_helper.h @@ -0,0 +1,81 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_PROTOBUF_HELPER_H_ +#define DARWINN_PORT_PROTOBUF_HELPER_H_ + +#ifdef PROTOBUF_INTERNAL_IMPL + +#include "net/proto2/io/public/coded_stream.h" +#include "net/proto2/io/public/zero_copy_stream_impl.h" +#include "net/proto2/public/map.h" +#include "net/proto2/public/repeated_field.h" +#include "net/proto2/public/text_format.h" + +using ::proto2::Map; // NOLINT +using ::proto2::RepeatedField; // NOLINT +using ::proto2::RepeatedPtrField; // NOLINT +using ::proto2::TextFormat; // NOLINT +using ::proto2::io::CodedOutputStream; // NOLINT +using ::proto2::io::FileOutputStream; // NOLINT +using ::proto2::io::StringOutputStream; // NOLINT + +#else + +#include +#include +#include +#include +#include + +using ::google::protobuf::Map; // NOLINT +using ::google::protobuf::RepeatedField; // NOLINT +using ::google::protobuf::RepeatedPtrField; // NOLINT +using ::google::protobuf::TextFormat; // NOLINT +using ::google::protobuf::io::CodedOutputStream; // NOLINT +using ::google::protobuf::io::FileOutputStream; // NOLINT +using ::google::protobuf::io::StringOutputStream; // NOLINT + +#endif + +// Note that CodedOutputStream.SetSerializationDeterministic is not available +// in qt-qpr1-dev branch (but it's available in master, which has more recent +// protobuf sources), due to which the following doesn't build in the qpr +// branch. +#if !defined(DARWINN_PORT_ANDROID_SYSTEM) + +#include + +namespace platforms { +namespace darwinn { + +// Produces deterministic serializations, so that the result can be used for +// comparison and hashing. +template +std::string SerializeProto(const Message& message) { + std::string result; + StringOutputStream stream(&result); + CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + message.SerializeToCodedStream(&output); + output.Trim(); + return result; +} + +} // namespace darwinn +} // namespace platforms + +#endif // !defined(DARWINN_PORT_ANDROID_SYSTEM) + +#endif // DARWINN_PORT_PROTOBUF_HELPER_H_ diff --git a/port/ptr_util.h b/port/ptr_util.h new file mode 100644 index 0000000..081131c --- /dev/null +++ b/port/ptr_util.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_PTR_UTIL_H_ +#define DARWINN_PORT_PTR_UTIL_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "util/gtl/ptr_util.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/ptr_util.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +#endif // DARWINN_PORT_PTR_UTIL_H_ diff --git a/port/semaphore.h b/port/semaphore.h new file mode 100644 index 0000000..a02110e --- /dev/null +++ b/port/semaphore.h @@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_SEMAPHORE_H_ +#define DARWINN_PORT_SEMAPHORE_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_FIRMWARE +#include "third_party/safertos_addons/semaphore.h" + +namespace platforms { +namespace darwinn { + +using BinarySemaphore = safertos_addons::BinarySemaphore; +using CountingSemaphore = safertos_addons::CountingSemaphore; + +} // namespace darwinn +} // namespace platforms + +#else +#include "port/default/semaphore.h" + +#endif // DARWINN_PORT_DEFAULT +#endif // DARWINN_PORT_SEMAPHORE_H_ diff --git a/port/shared_mutex.cc b/port/shared_mutex.cc new file mode 100644 index 0000000..10c4994 --- /dev/null +++ b/port/shared_mutex.cc @@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/shared_mutex.h" + +namespace platforms { +namespace darwinn { + +void SharedMutex::ReadLock() { + std::unique_lock lock(mutex_); + // Waits for the write lock to be released. + cond_.wait(lock, [this]{ return !is_writing_; }); + ++reader_count_; +} + +void SharedMutex::ReadUnlock() { + std::unique_lock lock(mutex_); + --reader_count_; + if (reader_count_ == 0) { + // Notifies the writer thread after this read. We have to notify all threads + // because there may be a reader waiting behind the writers, and notify_one + // may target a reader. + cond_.notify_all(); + } +} + +void SharedMutex::WriteLock() { + std::unique_lock lock(mutex_); + // Waits for any other writer thread to finish. 
+ cond_.wait(lock, [this]{ return !is_writing_; }); + // Indicates that a writer thread is waiting. This blocks other reader + // threads to acquire the lock. + is_writing_ = true; + // Waits for any reader thread to finish. + cond_.wait(lock, [this]{ return reader_count_ == 0; }); +} + +void SharedMutex::WriteUnlock() { + std::unique_lock lock(mutex_); + is_writing_ = false; + // Notifies all pending reader / writer threads. + cond_.notify_all(); +} + +} // namespace darwinn +} // namespace platforms diff --git a/port/shared_mutex.h b/port/shared_mutex.h new file mode 100644 index 0000000..09eb0b5 --- /dev/null +++ b/port/shared_mutex.h @@ -0,0 +1,142 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_SHARED_MUTEX_H_ +#define DARWINN_PORT_SHARED_MUTEX_H_ + +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "port/thread_annotations.h" + +namespace platforms { +namespace darwinn { + +// A simple implementation of a reader / writer lock. +// +// This allows concurrent reader lock access, but when a writer lock is +// acquired, all other writers and readers will be blocked till the writer +// finishes. +// These locks are not reentrant. +// +// We created this since some of our third_party build targets do not support +// C++14 yet, which is when shared mutex was added. +// +// This implementation also prevents the problem of writer starving. When a +// writer waits for the lock, no other reader can hold the lock. This allows +// the writer to get the lock in a reasonable time. +// +// It is not recommended to use the locking functions in this class directly. +// Please use the scoped wrappers described later in this file. +// +// Example: +// +// SharedMutex mu; +// mu.ReadLock(); +// (some read-only operations...) +// mu.ReadUnlock(); +// +// mu.WriteLock(); +// (some write operations...) +// mu.WriteUnlock(); +class LOCKABLE SharedMutex { + public: + SharedMutex() + : reader_count_(0), + is_writing_(false) {} + + // Blocks the thread until it acquires the lock in shared mode. + void ReadLock() SHARED_LOCK_FUNCTION(); + + // Releases the read share of this SharedMutex. + void ReadUnlock() UNLOCK_FUNCTION(); + + // Blocks the thread untill it acquires the lock exclusively. + void WriteLock() EXCLUSIVE_LOCK_FUNCTION(); + + // Releases the writer lock. + void WriteUnlock() UNLOCK_FUNCTION(); + + private: + // Internal mutex for every reader / writer to hold before proceed. + std::mutex mutex_; + + // Condition variable for every thread to wait on other threads. + std::condition_variable cond_; + + // Count of current active reader. + int reader_count_; + + // True if the writer lock is owned by some thread. + bool is_writing_; +}; + +// Wrapper for the SharedMutex class, which acquires and releases the lock +// in reader / shared mode via RAII. 
+// +// Example: +// SharedMutex mu; +// foo() { +// ReaderMutexLock shared_lock(&mu); +// } +class SCOPED_LOCKABLE ReaderMutexLock { + public: + explicit ReaderMutexLock(SharedMutex* mu) SHARED_LOCK_FUNCTION(mu) + : mu_(mu) { + mu_->ReadLock(); + } + + // This class is neither copyable nor movable. + ReaderMutexLock(const ReaderMutexLock&) = delete; + ReaderMutexLock& operator=(const ReaderMutexLock&) = delete; + + ~ReaderMutexLock() UNLOCK_FUNCTION() { + mu_->ReadUnlock(); + } + + private: + SharedMutex * const mu_; +}; + +// Wrapper for the SharedMutex class, which acquires and releases the lock +// in writer / exclusive mode via RAII. +// +// Example: +// SharedMutex mu; +// foo() { +// WriterMutexLock exclusive_lock(&mu); +// } +class SCOPED_LOCKABLE WriterMutexLock { + public: + explicit WriterMutexLock(SharedMutex* mu) EXCLUSIVE_LOCK_FUNCTION(mu) + : mu_(mu) { + mu_->WriteLock(); + } + + // This class is neither copyable nor movable. + WriterMutexLock(const WriterMutexLock&) = delete; + WriterMutexLock& operator=(const WriterMutexLock&) = delete; + + ~WriterMutexLock() UNLOCK_FUNCTION() { + mu_->WriteUnlock(); + } + + private: + SharedMutex * const mu_; +}; + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_SHARED_MUTEX_H_ diff --git a/port/status.h b/port/status.h new file mode 100644 index 0000000..749ee57 --- /dev/null +++ b/port/status.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_STATUS_H_ +#define DARWINN_PORT_STATUS_H_ + +#include "port/defs.h" + +#if DARWINN_PORT_USE_GOOGLE3 +#include "util/task/status.h" +#else // !DARWINN_PORT_USE_GOOGLE3 +#include "port/default/status.h" +#endif // DARWINN_PORT_USE_GOOGLE3 + +namespace platforms { +namespace darwinn { + +#if DARWINN_PORT_USE_GOOGLE3 +using StatusCode = ::absl::StatusCode; +#else // !DARWINN_PORT_USE_GOOGLE3 +using StatusCode = util::error::Code; +#endif // DARWINN_PORT_USE_GOOGLE3 + +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_STATUS_H_ diff --git a/port/status_macros.h b/port/status_macros.h new file mode 100644 index 0000000..db64418 --- /dev/null +++ b/port/status_macros.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
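(Editor's aside: a short sketch of how the SharedMutex RAII wrappers from port/shared_mutex.h above are typically combined with the thread annotations. The cache class, its members, and the GUARDED_BY usage are illustrative assumptions, not code from this library.)

  // Assumes <map>, <string>, "port/shared_mutex.h" and
  // "port/thread_annotations.h" are available.
  class DeviceInfoCache {
   public:
    bool Lookup(int id, std::string* out) const {
      ReaderMutexLock reader(&mutex_);  // Shared mode: readers may overlap.
      auto it = entries_.find(id);
      if (it == entries_.end()) return false;
      *out = it->second;
      return true;
    }

    void Insert(int id, std::string value) {
      WriterMutexLock writer(&mutex_);  // Exclusive mode: blocks new readers.
      entries_[id] = std::move(value);
    }

   private:
    mutable SharedMutex mutex_;
    std::map<int, std::string> entries_ GUARDED_BY(mutex_);
  };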
+
+#ifndef DARWINN_PORT_STATUS_MACROS_H_
+#define DARWINN_PORT_STATUS_MACROS_H_
+
+#include "port/defs.h"
+
+#if DARWINN_PORT_USE_GOOGLE3
+#include "util/task/status_macros.h"
+#else  // !DARWINN_PORT_USE_GOOGLE3
+#include "port/default/status_macros.h"
+#endif  // DARWINN_PORT_USE_GOOGLE3
+
+#endif  // DARWINN_PORT_STATUS_MACROS_H_
diff --git a/port/statusor.h b/port/statusor.h
new file mode 100644
index 0000000..763adde
--- /dev/null
+++ b/port/statusor.h
@@ -0,0 +1,26 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_PORT_STATUSOR_H_
+#define DARWINN_PORT_STATUSOR_H_
+
+#include "port/defs.h"
+
+#if DARWINN_PORT_USE_GOOGLE3
+#include "util/task/statusor.h"
+#else  // !DARWINN_PORT_USE_GOOGLE3
+#include "port/default/statusor.h"
+#endif  // DARWINN_PORT_USE_GOOGLE3
+
+#endif  // DARWINN_PORT_STATUSOR_H_
diff --git a/port/std_mutex_lock.h b/port/std_mutex_lock.h
new file mode 100644
index 0000000..a466d26
--- /dev/null
+++ b/port/std_mutex_lock.h
@@ -0,0 +1,53 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_PORT_STD_MUTEX_LOCK_H_
+#define DARWINN_PORT_STD_MUTEX_LOCK_H_
+
+#include <mutex>  // NOLINT
+
+#include "port/thread_annotations.h"
+
+namespace platforms {
+namespace darwinn {
+
+// A wrapper around std::mutex lockers to enable thread annotations. The
+// constructor takes a pointer to a mutex, which resembles the MutexLock
+// interface.
+template <typename T>
+class SCOPED_LOCKABLE AnnotatedStdMutexLock : public T {
+ public:
+  explicit AnnotatedStdMutexLock(std::mutex* mu) EXCLUSIVE_LOCK_FUNCTION(mu)
+      : T(*mu) {}
+
+  // This class is neither copyable nor movable.
+  AnnotatedStdMutexLock(const AnnotatedStdMutexLock&) = delete;
+  AnnotatedStdMutexLock& operator=(const AnnotatedStdMutexLock&) = delete;
+
+  ~AnnotatedStdMutexLock() UNLOCK_FUNCTION() = default;
+};
+
+// Intended to be used as a direct replacement of ReaderMutexLock/MutexLock.
+// The mutex is locked when constructed, and unlocked when destructed.
+typedef AnnotatedStdMutexLock<std::lock_guard<std::mutex>> StdMutexLock;
+
+// Intended to be used as a direct replacement of ReaderMutexLock/MutexLock
+// only when std::condition_variable is used with the mutex. Use StdMutexLock
+// otherwise. The mutex is locked when constructed, and unlocked when
+// destructed.
+typedef AnnotatedStdMutexLock<std::unique_lock<std::mutex>> StdCondMutexLock;
+
+}  // namespace darwinn
+}  // namespace platforms
+
+#endif  // DARWINN_PORT_STD_MUTEX_LOCK_H_
diff --git a/port/string_util.h b/port/string_util.h
new file mode 100644
index 0000000..7666e44
--- /dev/null
+++ b/port/string_util.h
@@ -0,0 +1,20 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_PORT_STRING_UTIL_H_
+#define DARWINN_PORT_STRING_UTIL_H_
+
+#include "port/default/strcat.h"
+
+#endif  // DARWINN_PORT_STRING_UTIL_H_
diff --git a/port/stringprintf.h b/port/stringprintf.h
new file mode 100644
index 0000000..e98f6d4
--- /dev/null
+++ b/port/stringprintf.h
@@ -0,0 +1,27 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO Move this to string_util.h.
+#ifndef DARWINN_PORT_STRINGPRINTF_H_
+#define DARWINN_PORT_STRINGPRINTF_H_
+
+#include "port/defs.h"
+
+#if DARWINN_PORT_USE_GOOGLE3
+#include "base/stringprintf.h"
+#else  // !DARWINN_PORT_USE_GOOGLE3
+#include "port/default/stringprintf.h"
+#endif  // DARWINN_PORT_USE_GOOGLE3
+
+#endif  // DARWINN_PORT_STRINGPRINTF_H_
diff --git a/port/thread_annotations.h b/port/thread_annotations.h
new file mode 100644
index 0000000..52e2256
--- /dev/null
+++ b/port/thread_annotations.h
@@ -0,0 +1,22 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DARWINN_PORT_THREAD_ANNOTATIONS_H_
+#define DARWINN_PORT_THREAD_ANNOTATIONS_H_
+
+#include "port/defs.h"
+// Built for google3 or portable (Android).
+#include "port/default/thread_annotations.h"
+
+#endif  // DARWINN_PORT_THREAD_ANNOTATIONS_H_
diff --git a/port/time.h b/port/time.h
new file mode 100644
index 0000000..beebf09
--- /dev/null
+++ b/port/time.h
@@ -0,0 +1,79 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
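(Editor's aside: std_mutex_lock.h keeps two typedefs because std::condition_variable::wait() needs the std::unique_lock that StdCondMutexLock wraps, while StdMutexLock is the cheaper guard for plain critical sections. A minimal sketch, with the queue type and names being illustrative assumptions:)

  // Assumes <condition_variable>, <queue> and "port/std_mutex_lock.h".
  class WorkQueue {
   public:
    void Push(int job) {
      StdMutexLock lock(&mutex_);      // lock_guard-based; wait() not needed.
      jobs_.push(job);
      cond_.notify_one();
    }

    int Pop() {
      StdCondMutexLock lock(&mutex_);  // unique_lock-based; usable with wait().
      cond_.wait(lock, [this] { return !jobs_.empty(); });
      const int job = jobs_.front();
      jobs_.pop();
      return job;
    }

   private:
    std::mutex mutex_;
    std::condition_variable cond_;
    std::queue<int> jobs_ GUARDED_BY(mutex_);
  };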
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_TIME_H_ +#define DARWINN_PORT_TIME_H_ + +#include "port/defs.h" +#include "port/integral_types.h" + +#if DARWINN_PORT_USE_GOOGLE3 + +#include "absl/time/clock.h" +#include "absl/time/time.h" + +// Returns the current timestamp in nanoseconds. +inline int64 GetCurrentTimeNanos() { + return absl::GetCurrentTimeNanos(); +} + +// Returns the current timestamp in microseconds. +inline int64 GetCurrentTimeMicros() { + return reinterpret_cast(absl::GetCurrentTimeNanos() / 1000); +} + +// Sleep for the specified amount of seconds. +inline void Sleep(int seconds) { + absl::SleepFor(absl::Seconds(seconds)); +} + +// Sleep for the specified amount of microseconds. +inline void Microsleep(int microseconds) { + absl::SleepFor(absl::Microseconds(microseconds)); +} + +#else // !DARWINN_PORT_USE_GOOGLE3 + + +#include // NOLINT +#include // NOLINT + +// Returns the current timestamp in nanoseconds. +inline platforms::darwinn::int64 GetCurrentTimeNanos() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +// Returns the current timestamp in microseconds. +inline platforms::darwinn::int64 GetCurrentTimeMicros() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +// Sleep for the specified amount of seconds. +inline void Sleep(int seconds) { + std::this_thread::sleep_for(std::chrono::seconds(seconds)); +} + +// Sleep for the specified amount of microseconds. +inline void Microsleep(int microseconds) { + std::this_thread::sleep_for(std::chrono::microseconds(microseconds)); +} + +#endif // DARWINN_PORT_USE_GOOGLE3 + + +#endif // DARWINN_PORT_TIME_H_ diff --git a/port/timer.h b/port/timer.h new file mode 100644 index 0000000..f4d3186 --- /dev/null +++ b/port/timer.h @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_TIMER_H_ +#define DARWINN_PORT_TIMER_H_ + +#if defined(_WIN32) +#include "port/timer_windows.h" +#elif defined(__APPLE__) +#include "port/timer_darwin.h" +#else +#include "port/timer_linux.h" +#endif // defined(_WIN32) + +#endif // DARWINN_PORT_TIMER_H_ diff --git a/port/timer_darwin.cc b/port/timer_darwin.cc new file mode 100644 index 0000000..af486ae --- /dev/null +++ b/port/timer_darwin.cc @@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/timer_darwin.h" + +namespace platforms { +namespace darwinn { +namespace api { + +util::Status Timer::Set(int64 nanos) { + std::lock_guard guard(mutex_); + + deadline_ = nanos == 0 ? Clock::time_point{Clock::duration::max()} + : Clock::now() + std::chrono::nanoseconds(nanos); + deadline_set_.notify_all(); + return util::OkStatus(); +} + +util::StatusOr Timer::Wait() { + while (true) { + std::unique_lock lock(mutex_); + auto now = Clock::now(); + if (now >= deadline_ || deadline_set_.wait_for(lock, deadline_ - now) == + std::cv_status::timeout) + return 1; + } +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/port/timer_darwin.h b/port/timer_darwin.h new file mode 100644 index 0000000..06f949c --- /dev/null +++ b/port/timer_darwin.h @@ -0,0 +1,62 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_TIMER_DARWIN_H_ +#define DARWINN_PORT_TIMER_DARWIN_H_ + +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// A simple interface for countdown timers. +class Timer { + public: + Timer() = default; + virtual ~Timer() = default; + + // This class is neither copyable nor movable. + Timer(const Timer&) = delete; + Timer& operator=(const Timer&) = delete; + + // Sets the timer to the specified nanoseconds. Countdown immediately starts + // after setting. Setting to 0 will de-activate the timer. + virtual util::Status Set(int64 nanos); + + // Waits for the timer to reach 0 and returns. If timer is de-activated before + // reaching 0 or never activated, this call will never return. + virtual util::StatusOr Wait(); + private: + using Clock = std::chrono::steady_clock; + // Deadline until Wait() method is blocked. + Clock::time_point deadline_{Clock::duration::max()}; + // Mutex which guards condition variable. + std::mutex mutex_; + // Condition variable to wait on (until deadline is reached). + std::condition_variable deadline_set_; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + + +#endif // DARWINN_PORT_TIMER_DARWIN_H_ diff --git a/port/timer_linux.cc b/port/timer_linux.cc new file mode 100644 index 0000000..351c3c8 --- /dev/null +++ b/port/timer_linux.cc @@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
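(Editor's aside: a hedged sketch of how the countdown Timer above might be used as a deadline watchdog; the thread structure and the 50 ms value are illustrative assumptions, not code from this library.)

  platforms::darwinn::api::Timer watchdog;

  std::thread waiter([&watchdog] {
    // Blocks until the countdown armed below expires. If the timer were
    // instead deactivated with Set(0), Wait() would never return.
    auto expirations = watchdog.Wait();
    if (expirations.ok()) {
      // Handle the missed deadline here.
    }
  });

  watchdog.Set(50LL * 1000 * 1000);  // Arm a 50 ms countdown (value in ns).
  // ... do the time-bounded work, then join or re-arm as appropriate ...
  waiter.join();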
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/timer.h" + +#include +#include + +#include "port/errors.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace api { + +Timer::Timer() { + fd_ = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC); + CHECK_GE(fd_, 0) << StringPrintf("Failed to create timerfd: %s", + strerror(errno)); +} + +Timer::~Timer() { close(fd_); } + +util::Status Timer::Set(int64 nanos) { + itimerspec spec = { + .it_interval = + { + .tv_sec = 0, + .tv_nsec = 0, + }, + .it_value = + { + .tv_sec = static_cast(nanos / 1000000000), + .tv_nsec = static_cast(nanos % 1000000000), + }, + }; + + int return_code = timerfd_settime(fd_, 0, &spec, nullptr); + if (return_code != 0) { + return util::InternalError( + StringPrintf("Failed to set timer: %s", strerror(errno))); + } + + return util::OkStatus(); +} + +util::StatusOr Timer::Wait() { + uint64 expirations; + size_t bytes_read = read(fd_, &expirations, sizeof(uint64)); + if (errno == EINTR) { + return 0; + } + if (bytes_read != sizeof(uint64)) { + return util::InternalError(StringPrintf( + "Timer read failed (%zu bytes read): %s", bytes_read, strerror(errno))); + } + + return expirations; +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/port/timer_linux.h b/port/timer_linux.h new file mode 100644 index 0000000..a124a1b --- /dev/null +++ b/port/timer_linux.h @@ -0,0 +1,52 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_TIMER_LINUX_H_ +#define DARWINN_PORT_TIMER_LINUX_H_ + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// A simple interface for countdown timers. +class Timer { + public: + Timer(); + virtual ~Timer(); + + // This class is neither copyable nor movable. + Timer(const Timer&) = delete; + Timer& operator=(const Timer&) = delete; + + // Sets the timer to the specified nanoseconds. Countdown immediately starts + // after setting. Setting to 0 will de-activate the timer. + virtual util::Status Set(int64 nanos); + + // Waits for the timer to reach 0 and returns. If timer is de-activated before + // reaching 0 or never activated, this call will never return. + virtual util::StatusOr Wait(); + private: + // File handle for timerfd. 
+ int fd_; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_TIMER_LINUX_H_ diff --git a/port/timer_windows.cc b/port/timer_windows.cc new file mode 100644 index 0000000..9cea09e --- /dev/null +++ b/port/timer_windows.cc @@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "port/timer.h" + +#include +#include + +#include "port/errors.h" +#include "port/stringprintf.h" + +namespace platforms { +namespace darwinn { +namespace api { + +Timer::Timer() { + timer_handle_ = CreateWaitableTimer(NULL, FALSE, NULL); + CHECK(timer_handle_) << StringPrintf("CreateWaitableTimer failed: %d", + GetLastError()); +} + +Timer::~Timer() { CloseHandle(timer_handle_); } + +util::Status Timer::Set(int64 nanos) { + // Negative times represent relative times. + // SetWaitableTimer units are 100ns. + LARGE_INTEGER dueTime; + dueTime.QuadPart = -(nanos / 100); + bool ret = SetWaitableTimer(timer_handle_, &dueTime, 0, NULL, NULL, NULL); + if (!ret) { + return util::InternalError( + StringPrintf("Failed to set timer: %d", GetLastError())); + } + return util::OkStatus(); +} + +util::StatusOr Timer::Wait() { + if (WaitForSingleObject(timer_handle_, INFINITE) == WAIT_OBJECT_0) { + return 1; + } + return 0; +} + +} // namespace api +} // namespace darwinn +} // namespace platforms diff --git a/port/timer_windows.h b/port/timer_windows.h new file mode 100644 index 0000000..b42898e --- /dev/null +++ b/port/timer_windows.h @@ -0,0 +1,52 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_TIMER_WINDOWS_H_ +#define DARWINN_PORT_TIMER_WINDOWS_H_ + +#include "port/integral_types.h" +#include "port/status.h" +#include "port/statusor.h" + +namespace platforms { +namespace darwinn { +namespace api { + +// A simple interface for countdown timers. +class Timer { + public: + Timer(); + virtual ~Timer(); + + // This class is neither copyable nor movable. + Timer(const Timer&) = delete; + Timer& operator=(const Timer&) = delete; + + // Sets the timer to the specified nanoseconds. Countdown immediately starts + // after setting. Setting to 0 will de-activate the timer. + virtual util::Status Set(int64 nanos); + + // Waits for the timer to reach 0 and returns. If timer is de-activated before + // reaching 0 or never activated, this call will never return. + virtual util::StatusOr Wait(); + private: + // Handle for WaitableTimer. 
+ void* timer_handle_; +}; + +} // namespace api +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_PORT_TIMER_WINDOWS_H_ diff --git a/port/tracing.h b/port/tracing.h new file mode 100644 index 0000000..88e7cd8 --- /dev/null +++ b/port/tracing.h @@ -0,0 +1,141 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_DEFAULT_SYSTRACE_H_ +#define DARWINN_PORT_DEFAULT_SYSTRACE_H_ + +#include + +#define DARWINN_SCOPE_PREFIX "DarwiNN::" + +// For Android binaries built on Android. +// We borrow the NNAPI systrace implementation to profile DarwiNN drivers built +// on Android. +#if defined(DARWINN_PORT_ANDROID_SYSTEM) + +#include "Tracing.h" + +#define TRACE_INITIALIZE() +// Use this only once per function, ideally at the beginning of each scope. +#define TRACE_SCOPE(name) NNTRACE_NAME_1(DARWINN_SCOPE_PREFIX name) + +// Use this to add trace markers in a scope that already has a TRACE_SCOPE. +#define TRACE_WITHIN_SCOPE(name) NNTRACE_NAME_SWITCH(DARWINN_SCOPE_PREFIX name) + +// Use this when a new thread starts up. +#define TRACE_START_THREAD(name) + +#define TRACE_DUMP(output_file) +#define TRACE_FINALIZE() + +// For Android binaries built on google3. +// When building on google3, blaze will not be able to link against the atrace +// symbols and we will need to dynamically link to it by ourselves. +#elif defined(__ANDROID__) && defined(DARWINN_ANDROID_GOOGLE3_TRACE_ENABLED) + +#include "port/tracer/darwinn_android_tracer.h" + +#define TRACE_SCOPE(name) \ + ::platforms::darwinn::DARWINN_ANDROID_TRACE_SCOPE(DARWINN_SCOPE_PREFIX name) + +#define TRACE_WITHIN_SCOPE(name) \ + ::platforms::darwinn::DARWINN_ANDROID_TRACE_SCOPE(DARWINN_SCOPE_PREFIX name) + +#define TRACE_START_THREAD(name) +#define TRACE_DUMP(output_file) +#define TRACE_FINALIZE() + +// Perfetto tracing for firmware can be enabled at build time. +// Build the firmware with --define darwinn_firmware_trace_enabled=1 +// and build run_graph_executor with --define darwinn_perfetto_trace_enabled=1. +#elif defined(__ANDROID__) && defined(DARWINN_PERFETTO_TRACE_ENABLED) + +#include "port/tracer/darwinn_perfetto_scoped_tracer.h" +#include "port/tracer/darwinn_perfetto_tracer.h" + +#define TRACE_INITIALIZE() ::platforms::darwinn::InitializeScopedPerfetto() + +#define TRACE_SCOPE(name) \ + PERFTETTO_TRACK_SCOPE(DARWINN_SCOPE_PREFIX name) + +#define TRACE_WITHIN_SCOPE(name) \ + PERFTETTO_TRACK_SCOPE(DARWINN_SCOPE_PREFIX name) + +#define TRACE_START_THREAD(name) +#define TRACE_DUMP(output_file) + +#define TRACE_FINALIZE() ::platforms::darwinn::FinalizePerfetto() + +// Web Tracing Framework can be enabled at build time. 
+// --define=GLOBAL_WTF_ENABLE=1 +#elif defined(WTF_ENABLE) + +#include "third_party/tracing_framework_bindings_cpp/macros.h" // IWYU pragma: export + +#define TRACE_INITIALIZE() +#define TRACE_SCOPE(name) WTF_SCOPE0(DARWINN_SCOPE_PREFIX name) +#define TRACE_WITHIN_SCOPE(name) WTF_EVENT0(DARWINN_SCOPE_PREFIX name) +#define TRACE_START_THREAD(name) WTF_THREAD_ENABLE(DARWINN_SCOPE_PREFIX name) +#define TRACE_DUMP(output_file) +#define TRACE_FINALIZE() + +#elif defined(DARWINN_CSV_TRACE_ENABLED) + +#include "port/tracer/darwinn_csv_tracer.h" + +#define TRACE_INITIALIZE() +#define TRACE_SCOPE(name) \ + ::platforms::darwinn::DARWINN_CSV_TRACE_SCOPE(DARWINN_SCOPE_PREFIX name) +#define TRACE_WITHIN_SCOPE(name) \ + ::platforms::darwinn::DARWINN_CSV_TRACE_SCOPE(DARWINN_SCOPE_PREFIX name) + +#define TRACE_START_THREAD(name) + +#define TRACE_DUMP(output_file) \ + ::platforms::darwinn::DarwinnCSVTracer::DumpTrace(output_file) + +#define TRACE_FINALIZE() + +// If xprof tracing is enabled at build time: --define=darwinn_xprof_enabled=1 +// To capture the trace, use perftools/gputools/profiler/xprof.sh. +#elif defined(DARWINN_XPROF_ENABLED) + +#include "port/tracer/darwinn_fw_xprof_tracer.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +#define _PASTE(x, y) x##y +#define PASTE(x, y) _PASTE(x, y) + +#define TRACE_INITIALIZE() +#define TRACE_SCOPE(name) \ + tensorflow::profiler::TraceMe PASTE(activity, \ + __LINE__)(DARWINN_SCOPE_PREFIX name) +#define TRACE_WITHIN_SCOPE(name) TRACE_SCOPE(name) +#define TRACE_START_THREAD(name) TRACE_SCOPE(name) + +#define TRACE_FINALIZE() + +// No tracing for other environments. +#else + +#define TRACE_INITIALIZE() +#define TRACE_SCOPE(name) +#define TRACE_WITHIN_SCOPE(name) +#define TRACE_START_THREAD(name) +#define TRACE_DUMP(output_file) +#define TRACE_FINALIZE() + +#endif + +#endif // DARWINN_PORT_DEFAULT_SYSTRACE_H_ diff --git a/port/unreachable.h b/port/unreachable.h new file mode 100644 index 0000000..a318c04 --- /dev/null +++ b/port/unreachable.h @@ -0,0 +1,20 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_PORT_UNREACHABLE_H_ +#define DARWINN_PORT_UNREACHABLE_H_ + +#include "port/default/unreachable.h" + +#endif // DARWINN_PORT_UNREACHABLE_H_ diff --git a/rename_library.py b/rename_library.py new file mode 100644 index 0000000..776adb3 --- /dev/null +++ b/rename_library.py @@ -0,0 +1,86 @@ +# Lint as: python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
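(Editor's aside on port/tracing.h above: whichever backend is selected at build time, call sites look the same, and with no backend the macros compile to nothing. A small sketch, with the function and scope names being illustrative assumptions.)

  void RunInference(const Request& request) {
    TRACE_SCOPE("Driver::RunInference");    // At most one per scope.
    PrepareInputs(request);
    TRACE_WITHIN_SCOPE("Driver::Submit");   // Extra marker inside that scope.
    SubmitToTpu(request);
  }

  void WorkerThreadBody() {
    TRACE_START_THREAD("Driver::Worker");   // Announce the new thread once.
    // ... poll for completions ...
  }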
+"""Takes a Windows DLL as input, and generates a new link library to allow the original DLL to be renamed. + +Because this requires dumpbin.exe and lib.exe on the PATH, it's recommended to +run it from the Visual Studio command prompt. + +As a result of running this script, input_dll will be renamed to output_dll, +and a new companion .if.lib file will be generated to go with output_dll. + +Args: + input_dll: Path to original DLL + output_dll: Path to create new DLL +""" + +# Please note: this script does not run in google3, so limit dependencies +# to the Python standard library. + +import argparse +import os +import re +import subprocess +import sys + + +def main(): + # Use the DUMPBIN tool to extract exported symbol names. + dumpbin_out = subprocess.check_output(['dumpbin', '/exports', args.input_dll]) + + # Functions from dumpbin look like this: + # 1 0 0003F9C0 ??0EdgeTpuManager@edgetpu@@QEAA@AEBV01@@Z + matcher = re.compile(r'^\s*\d+\s+\w+\s+\w{8}\s+([^ ]+)') + + # Build a DEF file from the output of DUMPBIN. + output = 'EXPORTS' + os.linesep + for line in dumpbin_out.splitlines(): + matches = matcher.search(line.decode()) + if matches: + fn_name = matches.group(1) + output += fn_name + os.linesep + def_file_name = args.output_dll[:-4] + '.def' + lib_file_name = args.output_dll + '.if.lib' + exp_file_name = args.output_dll + '.if.exp' + with open(def_file_name, 'w') as output_def_file: + output_def_file.write(output) + + # Use the LIB tool to generate a new link library. + subprocess.check_output([ + 'lib', '/machine:x64', + '/def:%s' % os.path.basename(def_file_name), + '/out:%s' % os.path.basename(lib_file_name) + ], + cwd=os.path.dirname(args.output_dll)) + + # Move original DLL to new name. + os.rename(args.input_dll, args.output_dll) + + # Clean up intermediates. + os.remove(def_file_name) + os.remove(exp_file_name) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_dll', help='Path to library to rename', required=True) + parser.add_argument( + '--output_dll', help='Path to output renamed library', required=True) + args = parser.parse_args() + if not args.output_dll.endswith('.dll') or not args.input_dll.endswith( + '.dll'): + print('ERROR: input_dll and output_dll must end with .dll') + sys.exit(-1) + main() diff --git a/tflite/BUILD b/tflite/BUILD new file mode 100644 index 0000000..d0b54db --- /dev/null +++ b/tflite/BUILD @@ -0,0 +1,227 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "custom_op_data", + srcs = ["custom_op_data.cc"], + hdrs = ["custom_op_data.h"], + deps = [ + "//port", + "@flatbuffers", + ], +) + +cc_library( + name = "custom_op", + srcs = ["custom_op.cc"], + hdrs = ["custom_op.h"], + deps = [ + ":custom_op_data", + "//api:layer_information", + "//driver:package_registry", + "//port", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "custom_op_user_data_direct", + srcs = ["custom_op_user_data_direct.cc"], + hdrs = ["custom_op_user_data_direct.h"], + deps = [ + ":custom_op", + ":custom_op_data", + "//api:driver", + "//driver:package_registry", + "//port", + ], +) + +cc_library( + name = "edgetpu_context_direct_header", + hdrs = ["edgetpu_context_direct.h"], + deps = [ + ":custom_op_data", + ":custom_op_user_data_direct", + "//api:driver", + "//api:driver_factory", + "//api:driver_options_fbs", + "//api:package_reference", + "//driver:driver_factory", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "edgetpu_manager_direct_header", + hdrs = ["edgetpu_manager_direct.h"], + deps = [ + ":custom_op_data", + ":custom_op_user_data_direct", + ":edgetpu_context_direct_header", + "//api:driver", + "//api:driver_factory", + "//api:driver_options_fbs", + "//api:package_reference", + "//driver:driver_factory", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "edgetpu_context_direct", + srcs = ["edgetpu_context_direct.cc"], + hdrs = ["edgetpu_context_direct.h"], + deps = [ + ":custom_op_data", + ":custom_op_user_data_direct", + ":edgetpu_manager_direct_header", + "//api:driver", + "//api:driver_factory", + "//api:driver_options_fbs", + "//api:package_reference", + "//driver:driver_factory", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "edgetpu_manager_direct", + srcs = ["edgetpu_manager_direct.cc"], + hdrs = ["edgetpu_manager_direct.h"], + deps = [ + ":custom_op_data", + ":custom_op_user_data_direct", + ":edgetpu_context_direct", + "@com_google_absl//absl/strings:str_format", + "//api:driver", + "//api:driver_factory", + "//api:driver_options_fbs", + "//api:package_reference", + "//api:runtime_version", + "//driver:driver_factory", + "//port", + "//port:std_mutex_lock", + "//port:thread_annotations", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "custom_op_direct", + srcs = ["custom_op_direct.cc"], + deps = [ + ":custom_op", + ":custom_op_data", + ":custom_op_user_data_direct", + ":edgetpu_c", # buildcleaner: keep + ":edgetpu_context_direct", + ":edgetpu_delegate_for_custom_op_tflite_plugin", # buildcleaner: keep + ":edgetpu_manager_direct", + "//api:driver", + "//port", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:context", + ], + alwayslink = 1, +) + +cc_library( + name = "tensor_data_controller", + srcs = ["tensor_data_controller.cc"], + hdrs = ["tensor_data_controller.h"], + deps = [ + "//port", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + ], +) + +cc_library( + name = "edgetpu_context_factory", + srcs = 
["edgetpu_context_factory.cc"], + hdrs = ["edgetpu_context_factory.h"], + deps = [ + ":edgetpu_context_direct", + "//port", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:framework", + ], +) + +# Set EXTRA_LOGGING to DUMMY to diable extra logging in text protobuf. +config_setting( + name = "no_extra_logging", + values = {"define": "EXTRA_LOGGING=DUMMY"}, +) + +cc_library( + name = "edgetpu_delegate_for_custom_op", + srcs = [ + "edgetpu_delegate_for_custom_op.cc", + ], + hdrs = ["edgetpu_delegate_for_custom_op.h"], + deps = [ + "//port", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:kernel_api", + "@org_tensorflow//tensorflow/lite:util", + ], +) + +cc_library( + name = "edgetpu_delegate_for_custom_op_tflite_plugin", + srcs = [ + "edgetpu_delegate_for_custom_op_tflite_plugin.cc", + ], + deps = [ + ":edgetpu_delegate_for_custom_op", + "@com_google_absl//absl/strings", + "//tflite/public:edgetpu", + "@org_tensorflow//tensorflow/lite:kernel_api", + ], + alwayslink = 1, +) + +cc_library( + name = "edgetpu_c", + srcs = [ + "edgetpu_c.cc", + ], + deps = [ + ":edgetpu_delegate_for_custom_op", + "//port", + "//tflite/public:edgetpu", + "//tflite/public:edgetpu_c", + ], + alwayslink = 1, +) diff --git a/tflite/custom_op.cc b/tflite/custom_op.cc new file mode 100644 index 0000000..3ec6a37 --- /dev/null +++ b/tflite/custom_op.cc @@ -0,0 +1,318 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/custom_op.h" + +#include + +#include "api/layer_information.h" +#include "driver/package_registry.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/status_macros.h" +#include "port/stringprintf.h" +#include "tflite/custom_op_data.h" +#include "tflite/public/edgetpu.h" +#include "tensorflow/lite/context.h" + +#define RETURN_IF_NOT_EQ(a, b) \ + do { \ + if ((a) != (b)) { \ + return util::InternalError(StringPrintf("%s:%d %s != %s (%d != %d)", \ + __FILE__, __LINE__, #a, #b, \ + (int)(a), (int)(b))); \ + } \ + } while (0) + +namespace platforms { +namespace darwinn { +namespace tflite { + +namespace { + +bool IsFloat32ClassifierLayer(const api::OutputLayerInformation* output_layer) { + return (output_layer->y_dim() == 1 && output_layer->x_dim() == 1 && + output_layer->data_type() == platforms::darwinn::DataType_SINGLE); +} + +bool IsUint16ClassifierLayer(const api::OutputLayerInformation* output_layer) { + return (output_layer->y_dim() == 1 && output_layer->x_dim() == 1 && + output_layer->data_type() == + platforms::darwinn::DataType_FIXED_POINT16); +} + +// Returns the number of bytes occupied by a value of the given data type. +// Note that only a subset of data types are currently supported. 
+util::StatusOr SizeOfDataType(TfLiteType data_type) { + switch (data_type) { + case kTfLiteUInt8: + case kTfLiteInt8: + return sizeof(uint8_t); + + case kTfLiteInt16: + return sizeof(uint16_t); + + case kTfLiteInt32: + return sizeof(uint32_t); + + case kTfLiteFloat16: + return 2; + + case kTfLiteFloat32: + return sizeof(float); + + default: + return util::InternalError(StringPrintf( + "Unsupported data type in custom op handler: %d", data_type)); + } +} + +// Checks that the DarwiNN and TFLite types are compatible. For input layers, +// they must match exactly. For output layers, they are compatible if +// ReFormatOutputs is capable of converting from the DarwiNN type to the +// corresponding TFLite type. +// Note that only a subset of data types are currently supported. +util::Status ValidateDataType( + TfLiteType tf_lite_type, darwinn::DataType darwinn_type, + const api::OutputLayerInformation* optional_output_layer) { + switch (darwinn_type) { + case DataType_FIXED_POINT8: + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteUInt8); + break; + case DataType_SIGNED_FIXED_POINT8: + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteInt8); + break; + + case DataType_FIXED_POINT16: + if (optional_output_layer != nullptr && tf_lite_type == kTfLiteUInt8 && + IsUint16ClassifierLayer(optional_output_layer)) { + return util::OkStatus(); + } + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteInt16); + break; + + case DataType_SIGNED_FIXED_POINT16: + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteInt16); + break; + + case DataType_SIGNED_FIXED_POINT32: + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteInt32); + break; + + case DataType_SINGLE: + if (optional_output_layer != nullptr && tf_lite_type == kTfLiteUInt8 && + IsFloat32ClassifierLayer(optional_output_layer)) { + return util::OkStatus(); + } + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteFloat32); + break; + + case DataType_HALF: + RETURN_IF_NOT_EQ(tf_lite_type, kTfLiteFloat16); + break; + + default: + return util::InternalError( + StringPrintf("Unsupported layer data type in custom op handler: %d", + darwinn_type)); + } + + return util::OkStatus(); +} + +// Validates input and output count, type and sizes against DarwiNN executable. +// Also resizes output tensors to the correct batch size. +util::Status ValidateInputsAndOutputs( + TfLiteContext* context, TfLiteNode* node, + const driver::ExecutableLayersInfo* executable_layers_info) { + int batches = 0; + + // Validate inputs. + CustomOpUserData* user_data = + reinterpret_cast(node->user_data); + RETURN_IF_NOT_EQ(executable_layers_info->NumInputLayers(), + user_data->GetInputs(node)->size); + + for (int i = 0; i < executable_layers_info->NumInputLayers(); ++i) { + const TfLiteTensor* input = GetInput(context, node, i); + ASSIGN_OR_RETURN(const int size_of_data_type, SizeOfDataType(input->type)); + const int single_input_size = + executable_layers_info->InputLayerSize(i) * size_of_data_type; + + // Data types must match. + RETURN_IF_ERROR(ValidateDataType( + input->type, executable_layers_info->InputLayer(i)->data_type(), + nullptr)); + + // Check for a batch dimension. The batch dimension is always assumed to be + // the first dimension. + int input_batches = 1; + if (input->dims->size >= 1 && + input->dims->data[0] * single_input_size == input->bytes) { + input_batches = input->dims->data[0]; + } + + // All inputs must have the same number of batches. 
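+    // For example (illustrative numbers): if an input layer holds 100 uint8
+    // values, single_input_size is 100 bytes, so a [4, 100] tensor with 400
+    // bytes was inferred as 4 batches above, while a plain [100] tensor stays
+    // at 1 batch. Every input must now agree on the same count.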
+ if (batches == 0) { + batches = input_batches; + } else { + RETURN_IF_NOT_EQ(batches, input_batches); + } + + RETURN_IF_NOT_EQ(batches * single_input_size, input->bytes); + } + + // |batches| == 0 means there were no inputs. Treat that as 1 batch, because + // we always want to run the model at least once. + if (batches == 0) { + batches = 1; + } + + // Validate outputs. + RETURN_IF_NOT_EQ(executable_layers_info->NumOutputLayers(), + node->outputs->size); + for (int i = 0; i < executable_layers_info->NumOutputLayers(); ++i) { + TfLiteTensor* output = GetOutput(context, node, i); + ASSIGN_OR_RETURN(const int size_of_data_type, SizeOfDataType(output->type)); + const int single_output_size = + executable_layers_info->OutputLayerSize(i) * size_of_data_type; + + // Data types must match. + RETURN_IF_ERROR(ValidateDataType( + output->type, executable_layers_info->OutputLayer(i)->data_type(), + executable_layers_info->OutputLayer(i))); + + // If there's a batch dimension on the output, set it to the correct size. + // Note that this has to be done even for batches == 1, in case the tensor + // has to be resized down. + if (output->dims->size >= 1 && + output->dims->data[0] * single_output_size == output->bytes) { + if (batches != output->dims->data[0]) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(output->dims->size); + output_size->data[0] = batches; + for (int dim = 1; dim < output->dims->size; ++dim) { + output_size->data[dim] = output->dims->data[dim]; + } + context->ResizeTensor(context, output, output_size); + } + } else { + RETURN_IF_NOT_EQ(batches, 1); + } + + RETURN_IF_NOT_EQ(batches * single_output_size, output->bytes); + } + + user_data->SetBatches(batches); + return util::OkStatus(); +} + +} // namespace + +CustomOpUserData::~CustomOpUserData() { + if (inputs_) { + TfLiteIntArrayFree(inputs_); + inputs_ = nullptr; + } +} + +const std::string& CustomOpUserData::GetSessionName() const { + return session_name_; +} + +bool CustomOpUserData::GetShouldPopulateCache() const { + return should_populate_cache_; +} + +void CustomOpUserData::SetShouldPopulateCache(bool should_populate_cache) { + should_populate_cache_ = should_populate_cache; +} + +TfLiteIntArray* CustomOpUserData::GetInputs() const { return inputs_; } + +TfLiteIntArray* CustomOpUserData::GetInputs(TfLiteNode* node) const { + if (inputs_) { + return inputs_; + } else { + CHECK_NE(node, nullptr); + return node->inputs; + } +} + +const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { + CustomOpUserData* user_data = + reinterpret_cast(node->user_data); + + return &context->tensors[user_data->GetInputs(node)->data[index]]; +} + +TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, int index) { + return &context->tensors[node->outputs->data[index]]; +} + +TfLiteStatus CustomOpPrepare(TfLiteContext* context, TfLiteNode* node) { + if (!node->user_data) { + context->ReportError(context, "Failed to prepare a custom op."); + return kTfLiteError; + } + const auto* user_data = reinterpret_cast(node->user_data); + const auto* executable_layers_info = user_data->GetExecutableLayersInfo(); + CHECK_NE(executable_layers_info, nullptr); + + util::Status status = + ValidateInputsAndOutputs(context, node, executable_layers_info); + if (!status.ok()) { + context->ReportError(context, status.ToString().c_str()); + return kTfLiteError; + } + return kTfLiteOk; +} + +void CustomOpFree(TfLiteContext* context, void* buffer) { + // TODO: Unregister executables, once that is implemented. 
+ delete reinterpret_cast(buffer); +} + +util::Status ReFormatOutputs(TfLiteTensor* output, int output_tensor_offset, + int output_tensor_size, + const api::OutputLayerInformation* output_layer, + const unsigned char* output_data) { + // Although we have 8-bit classifier now, the following is kept for backwards + // compatibility with executables which were generated the old way. + if (output->type == kTfLiteUInt8 && IsFloat32ClassifierLayer(output_layer)) { + const float* tpu_output = reinterpret_cast(output_data); + for (int j = 0; j < output_tensor_size; ++j) { + int quantized_int = static_cast( + (tpu_output[j] / output->params.scale + output->params.zero_point)); + output->data.uint8[j + output_tensor_offset] = + std::min(std::max(quantized_int, 0), 255); + } + } else if (output->type == kTfLiteUInt8 && + IsUint16ClassifierLayer(output_layer)) { + const int16_t* tpu_output = reinterpret_cast(output_data); + for (int j = 0; j < output_tensor_size; ++j) { + output->data.uint8[j + output_tensor_offset] = + std::min(std::max(static_cast(tpu_output[j]), 0), 255); + } + } else { + memcpy(output->data.uint8 + output_tensor_offset, output_data, + output_tensor_size); + } + + return util::OkStatus(); +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/custom_op.h b/tflite/custom_op.h new file mode 100644 index 0000000..b1f946c --- /dev/null +++ b/tflite/custom_op.h @@ -0,0 +1,95 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_CUSTOM_OP_H_ +#define DARWINN_TFLITE_CUSTOM_OP_H_ + +#include "driver/package_registry.h" +#include "tflite/custom_op_data.h" +#include "tensorflow/lite/context.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// This structure encapsulates the data needed to run DarwiNN executables +// after they have been registered with the DarwiNN driver. +class CustomOpUserData { + public: + virtual ~CustomOpUserData(); + + // Session name determines which hardware/service to use, as well as + // management of cache shared by all models running on the same hardware. + // TODO: add session manager to map between name and + // settings/status/control. 
+ const std::string& GetSessionName() const; + + const driver::ExecutableLayersInfo* GetExecutableLayersInfo() const { + return executable_layers_info_; + } + + TfLiteIntArray* GetInputs() const; + TfLiteIntArray* GetInputs(TfLiteNode* node) const; + void SetInputs(TfLiteIntArray* inputs) { inputs_ = inputs; } + + bool GetShouldPopulateCache() const; + void SetShouldPopulateCache(bool should_populate_cache); + + int GetBatches() const { return batches_; } + void SetBatches(int batches) { batches_ = batches; } + + protected: + CustomOpUserData() = default; + + std::string session_name_; + bool should_populate_cache_{true}; + + int batches_{1}; + + // Pointer to the layer info of the executable binary; + const driver::ExecutableLayersInfo* executable_layers_info_{nullptr}; + + // When we use the custom-op implementation to run a delegate op, we can't + // look at the node's inputs to find all the input activation tensors to this + // delegate op (since the node's inputs include all the bias/parameter + // tensors as well). Instead we will look at this array. + TfLiteIntArray* inputs_{nullptr}; +}; + +// Returns input tensor from node. +const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index); + +// Returns output tensor from node. +TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, int index); + +// Re-format output data. +util::Status ReFormatOutputs(TfLiteTensor* output, int output_tensor_offset, + int output_tensor_size, + const api::OutputLayerInformation* output_layer, + const unsigned char* output_data); + +// Prepares custom-op for operation. +TfLiteStatus CustomOpPrepare(TfLiteContext* context, TfLiteNode* node); + +// Deallocates the custom-op data object (CustomOpUserData). The lifetime of +// this object is managed by the TfLite interpreter, which calls this function +// to deallocate this object. +void CustomOpFree(TfLiteContext* context, void* buffer); + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_CUSTOM_OP_H_ diff --git a/tflite/custom_op_data.cc b/tflite/custom_op_data.cc new file mode 100644 index 0000000..b5b73de --- /dev/null +++ b/tflite/custom_op_data.cc @@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/custom_op_data.h" + +#include "port/ptr_util.h" +#include "port/status_macros.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +namespace { + +// CustomOpData struct will be serialized as a Flexbuffer map with the +// following keys. +// "1" ---> integer (version) +// "2" ---> string (DEPRECATED; chip_name). +// "3" ---> string (DEPRECATED; serialized parameter-caching executable) +// "4" ---> string (serialized executable) + +static const char kKeyVersion[] = "1"; +// DEPRECATED (Don't reuse key for something else). 
+// static const char kKeyChipName[] = "2"; +static const char kKeyParameterCachingExecutable[] = "3"; +static const char kKeyExecutable[] = "4"; +static const char kExecutionPreference[] = "5"; + +} // namespace + +std::unique_ptr SerializeCustomOpData( + const CustomOpData& custom_op_data) { + auto builder = gtl::MakeUnique(); + size_t map_start = builder->StartMap(); + builder->Int(kKeyVersion, custom_op_data.version); + builder->Key(kKeyExecutable); + builder->String(custom_op_data.executable.data, + custom_op_data.executable.length); + builder->Int(kExecutionPreference, custom_op_data.execution_preference); + builder->EndMap(map_start); + builder->Finish(); + return builder; +} + +std::unique_ptr DeserializeCustomOpData(const uint8_t* buffer, + size_t length) { + if (!buffer || !length) { + LOG(ERROR) << "Failed to deserialize into CustomOpData object; " + << " buffer was " << (buffer ? "non-null" : "null") + << ", length was " << length << " bytes"; + return nullptr; + } + auto flexbuffer_map = flexbuffers::GetRoot(buffer, length).AsMap(); + // TODO Remove this check once the field is removed. + CHECK(flexbuffer_map[kKeyParameterCachingExecutable].IsNull()); + auto custom_op_data = gtl::MakeUnique(); + custom_op_data->version = flexbuffer_map[kKeyVersion].AsInt32(); + flexbuffers::String executable = flexbuffer_map[kKeyExecutable].AsString(); + custom_op_data->executable = {executable.c_str(), executable.length()}; + if (!flexbuffer_map[kExecutionPreference].IsNull()) { + custom_op_data->execution_preference = + flexbuffer_map[kExecutionPreference].AsInt32(); + } + return custom_op_data; +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/custom_op_data.h b/tflite/custom_op_data.h new file mode 100644 index 0000000..347a0f6 --- /dev/null +++ b/tflite/custom_op_data.h @@ -0,0 +1,65 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_CUSTOM_OP_DATA_H_ +#define DARWINN_TFLITE_CUSTOM_OP_DATA_H_ + +#include + +#include "flatbuffers/flexbuffers.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// Version 0. +static const int32_t kCustomOpDataVersion = 0; + +struct WrappedBuffer { + const char* data; + size_t length; +}; + +struct CustomOpData { + int32_t version; + // TODO Remove this field and update references in + // custom_op_data.cc. + WrappedBuffer parameter_caching_executable; + WrappedBuffer executable; + + // Execution preference code (currently used by NNAPI only). For more + // information, please see PreferenceCode in: + // https://cs.corp.google.com/android/frameworks/ml/nn/runtime/include/NeuralNetworks.h + // -1 results in NNAPI's default value (FAST_SINGLE_ANSWER=1), it is used as + // default here in order to help identify older custom ops or ones that are + // created in a code path that does not set execution preference. 
+ int32_t execution_preference = -1; +}; + +// Converts the input CustomOpData object into a flexbuffers object that stores +// it in a serializable form. +// This function expects the caller to provide a valid CustomOpData object. +std::unique_ptr SerializeCustomOpData( + const CustomOpData& custom_op_data); + +// Converts the input buffer into an in-memory CustomOpData object. +// Returns nullptr if buffer is null or if length is zero. +std::unique_ptr DeserializeCustomOpData(const uint8_t* buffer, + size_t length); + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_CUSTOM_OP_DATA_H_ diff --git a/tflite/custom_op_direct.cc b/tflite/custom_op_direct.cc new file mode 100644 index 0000000..0c38106 --- /dev/null +++ b/tflite/custom_op_direct.cc @@ -0,0 +1,160 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "api/driver.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/stringprintf.h" +#include "tflite/custom_op.h" +#include "tflite/custom_op_data.h" +#include "tflite/custom_op_user_data_direct.h" +#include "tflite/edgetpu_context_direct.h" +#include "tflite/edgetpu_manager_direct.h" +#include "tflite/public/edgetpu.h" +#include "tensorflow/lite/context.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +namespace { + +using edgetpu::EdgeTpuManager; + +// Various initializations steps for a DarwiNN custom op. +void* CustomOpInit(TfLiteContext* context, const char* buffer, size_t length) { + // Create new operator-specific user data. + // Note this data is different from interpreter-specific data recorded in + // context->GetExternalContext, which is probably not set yet when + // this function is called. + return new CustomOpUserDataDirect(reinterpret_cast(buffer), + length); +} + +// Returns either the associated TPU context. +EdgeTpuContextDirect* GetTpuContext(TfLiteContext* context) { + // Down-cast from TfLiteExternalContext* to EdgeTpuContextDirect* + return static_cast( + context->GetExternalContext(context, kTfLiteEdgeTpuContext)); +} + +// This function is called only when Interpreter believes it's needed, when +// within the call to Intrepreter::AllocateTensor. +TfLiteStatus CustomOpPrepareDirect(TfLiteContext* context, TfLiteNode* node) { + CustomOpUserDataDirect* user_data = + reinterpret_cast(node->user_data); + + if (!user_data) { + context->ReportError(context, "Null custom op data."); + return kTfLiteError; + } + + auto* interpreter_context = GetTpuContext(context); + if (!interpreter_context) { + context->ReportError(context, "Failed to retrieve TPU context."); + return kTfLiteError; + } + + // Binds this custom op instance with a particular driver instance. + // It actually registers the model with the driver specified in interpreter + // context. 
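+  // SetDriver() registers the serialized executable with that driver and
+  // caches the executable's layer info; calling it again with the same driver
+  // is a no-op, while a different driver fails with FailedPrecondition.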
+ auto result = user_data->SetDriver( + interpreter_context->GetDriverWrapper()->GetDriver()); + if (!result.ok()) { + context->ReportError(context, "Failed to prepare for TPU. %s", + result.ToString().c_str()); + return kTfLiteError; + } + + return CustomOpPrepare(context, node); +} + +// De-allocates the per-node-and-Interpreter custom data. +void CustomOpFreeDirect(TfLiteContext* context, void* buffer) { + // TODO: Remove the whole function after the new cache mechanism is + // ready. Use CustomOpFree instead. + CustomOpUserDataDirect* user_data = + reinterpret_cast(buffer); + + if (!user_data) { + context->ReportError(context, "Null custom op data."); + return; + } + + // Deleting user_data un-registers the model from the driver, if it has ever + // been registered. + delete user_data; +} + +TfLiteStatus CustomOpInvoke(TfLiteContext* context, TfLiteNode* node) { + CustomOpUserDataDirect* user_data = + reinterpret_cast(node->user_data); + + if (!user_data) { + context->ReportError(context, "Null custom op data."); + return kTfLiteError; + } + + auto* interpreter_context = GetTpuContext(context); + if (!interpreter_context) { + context->ReportError(context, "Failed to retrieve TPU context."); + return kTfLiteError; + } + + auto result = + interpreter_context->GetDriverWrapper()->InvokeExecutable(context, node); + if (!result.ok()) { + context->ReportError(context, StringPrintf("Failed to execute request. %s", + result.error_message().c_str()) + .c_str()); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +namespace edgetpu { + +TfLiteRegistration* RegisterCustomOp() { + static TfLiteRegistration registration = { + platforms::darwinn::tflite::CustomOpInit, + platforms::darwinn::tflite::CustomOpFreeDirect, + platforms::darwinn::tflite::CustomOpPrepareDirect, + platforms::darwinn::tflite::CustomOpInvoke, + }; + return ®istration; +} + +EdgeTpuManager* EdgeTpuManager::GetSingleton() { + return platforms::darwinn::tflite::EdgeTpuManagerDirect::GetSingleton(); +} + +std::ostream& operator<<(std::ostream& out, edgetpu::DeviceType device_type) { + out << platforms::darwinn::tflite::EdgeTpuDriverWrapper::GetDeviceTypeName( + device_type); + return out; +} + +} // namespace edgetpu diff --git a/tflite/custom_op_user_data_direct.cc b/tflite/custom_op_user_data_direct.cc new file mode 100644 index 0000000..4dd4c57 --- /dev/null +++ b/tflite/custom_op_user_data_direct.cc @@ -0,0 +1,96 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
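+
+// CustomOpUserDataDirect lifecycle, as wired up in custom_op_direct.cc:
+//   1. CustomOpInit() constructs it from the custom op's flexbuffer blob via
+//      DeserializeCustomOpData().
+//   2. CustomOpPrepareDirect() calls SetDriver(), which registers the
+//      serialized executable and extracts its layer info.
+//   3. CustomOpFreeDirect() deletes it, which unregisters the executable.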
+ +#include "tflite/custom_op_user_data_direct.h" + +#include "driver/package_registry.h" +#include "port/logging.h" +#include "port/stringprintf.h" +#include "tflite/custom_op_data.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +CustomOpUserDataDirect::CustomOpUserDataDirect(const uint8_t* buffer, + size_t length) + : raw_model_data_(DeserializeCustomOpData(buffer, length)) {} + +CustomOpUserDataDirect::~CustomOpUserDataDirect() { + (void)UnregisterExecutables(); +} + +util::Status CustomOpUserDataDirect::SetDriver(api::Driver* driver) { + if (!driver) { + return util::InvalidArgumentError("Cannot be assigned to nullptr."); + } + + if (driver_) { + if (driver == driver_) { + // It is okay to set to the same driver instance. + // Prepare could be called multiple times to the same set of operators. + return util::Status(); // OK. + } else { + return util::FailedPreconditionError( + "Custom op already assigned to a different TPU."); + } + } + driver_ = driver; + + if (!raw_model_data_) { + return util::FailedPreconditionError("Missing raw model data."); + } + + // Register the executable. + ASSIGN_OR_RETURN(executable_, driver_->RegisterExecutableSerialized( + raw_model_data_->executable.data, + raw_model_data_->executable.length)); + + // Gets the executable layer info from the executable binary. + // TODO: Merge the ExecutableLayersInfo and api::PackageReference + ASSIGN_OR_RETURN( + auto executable_layers_info_unique_ptr, + driver::PackageRegistry::GetMainExecutableLayersInfoFromBinary( + raw_model_data_->executable.data, + raw_model_data_->executable.length)); + + // The executable layer info will stay alive till it's deleted in unregister. + executable_layers_info_ = executable_layers_info_unique_ptr.release(); + + raw_model_data_.reset(); + + return util::Status(); // OK. +} + +util::Status CustomOpUserDataDirect::UnregisterExecutables() { + if (!driver_) { + return util::Status(); // OK. + } + + if (executable_) { + (void)driver_->UnregisterExecutable(executable_); + executable_ = nullptr; + } + + if (executable_layers_info_) { + delete executable_layers_info_; + executable_layers_info_ = nullptr; + } + + return util::Status(); // OK. +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/custom_op_user_data_direct.h b/tflite/custom_op_user_data_direct.h new file mode 100644 index 0000000..5b9b989 --- /dev/null +++ b/tflite/custom_op_user_data_direct.h @@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_CUSTOM_OP_USER_DATA_DIRECT_H_ +#define DARWINN_TFLITE_CUSTOM_OP_USER_DATA_DIRECT_H_ + +#include + +#include "api/driver.h" +#include "driver/package_registry.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "tflite/custom_op.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// Node-and-Interpreter-specific custom data. 
This is allocated by OpInit, +// which is called at Interpreter creation time, per custom op node. +// This object is non-threadsafe, for Interpreter itself is single threaded. +// Pointer to this object instance is needed to de-allocate resources +// associated with this custom op instance in the interpreter context, as +// pointer to upper level node instance is not passed to +// TfLiteRegistration::free callback. +class CustomOpUserDataDirect : public CustomOpUserData { + public: + CustomOpUserDataDirect(const uint8_t* buffer, size_t length); + + ~CustomOpUserDataDirect(); + + // Binds to a driver instance, and registers executables with this driver. + util::Status SetDriver(api::Driver* driver); + + // Returns the reference to the executable binary. + const api::PackageReference* GetExecutable() const { return executable_; } + + private: + // Unregisters executables with the associated driver. + util::Status UnregisterExecutables(); + + // Raw data parsed from tflite model file. + std::unique_ptr raw_model_data_; + + // Pointer to the driver instance associated with this custom op node. + // Note that a driver instance can be shared by many custom op nodes, and + // execution of all of these nodes would be serialized in + // EdgeTpuContextDirect::InvokeExecutable. Thread safty is guaranteed by + // Driver. + api::Driver* driver_{nullptr}; + + // Pointer to the reference of the executable binary; + const api::PackageReference* executable_{nullptr}; +}; + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_CUSTOM_OP_USER_DATA_DIRECT_H_ diff --git a/tflite/edgetpu_c.cc b/tflite/edgetpu_c.cc new file mode 100644 index 0000000..087dc43 --- /dev/null +++ b/tflite/edgetpu_c.cc @@ -0,0 +1,107 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
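+
+// Implements the C API declared in tflite/public/edgetpu_c.h on top of the
+// C++ EdgeTpuManager. Illustrative usage sketch (error handling and the
+// wiring of the delegate into a TfLite interpreter are omitted):
+//
+//   size_t num_devices = 0;
+//   edgetpu_device* devices = edgetpu_list_devices(&num_devices);
+//   if (num_devices > 0) {
+//     TfLiteDelegate* delegate = edgetpu_create_delegate(
+//         devices[0].type, devices[0].path, /*options=*/nullptr,
+//         /*num_options=*/0);
+//     // ... use the delegate with a TfLite interpreter ...
+//     edgetpu_free_delegate(delegate);
+//   }
+//   edgetpu_free_devices(devices);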
+ +#include "tflite/public/edgetpu_c.h" + +#include + +#include "port/logging.h" +#include "tflite/edgetpu_delegate_for_custom_op.h" +#include "tflite/public/edgetpu.h" + +using edgetpu::DeviceType; +using edgetpu::EdgeTpuContext; +using edgetpu::EdgeTpuManager; +using DeviceOptions = EdgeTpuManager::DeviceOptions; + +extern "C" { + +struct edgetpu_device* edgetpu_list_devices(size_t* num_devices) { + CHECK(num_devices); + + auto records = EdgeTpuManager::GetSingleton()->EnumerateEdgeTpu(); + if (records.empty()) { + *num_devices = 0; + return nullptr; + } + + const auto devs_size = sizeof(edgetpu_device) * records.size(); + + size_t size = devs_size; + for (const auto& record : records) size += record.path.size() + 1; + + char* memory = new char[size]; + edgetpu_device* devs = reinterpret_cast(memory); + char* paths = memory + devs_size; + + int i = 0; + for (const auto& record : records) { + edgetpu_device* dev = &devs[i++]; + dev->type = static_cast(record.type); + dev->path = paths; + + const auto len = record.path.size() + 1; + std::memcpy(paths, record.path.c_str(), len); + paths += len; + } + + *num_devices = records.size(); + return devs; +} + +void edgetpu_free_devices(struct edgetpu_device* dev) { + delete[] reinterpret_cast(dev); +} + +TfLiteDelegate* edgetpu_create_delegate(enum edgetpu_device_type type, + const char* name, + const struct edgetpu_option* options, + size_t num_options) { + auto* manager = EdgeTpuManager::GetSingleton(); + const auto device_type = static_cast(type); + + std::shared_ptr context; + if (num_options > 0) { + CHECK(options); + CHECK(name); + EdgeTpuManager::DeviceOptions device_options; + for (size_t i = 0; i < num_options; ++i) { + const edgetpu_option* option = &options[i]; + device_options.insert({option->name, option->value}); + } + context = manager->OpenDevice(device_type, name, device_options); + } else { + context = (name == nullptr) ? manager->OpenDevice(device_type) + : manager->OpenDevice(device_type, name); + } + + if (!context) return nullptr; + return platforms::darwinn::tflite::CreateEdgeTpuDelegateForCustomOp(context); +} + +void edgetpu_free_delegate(TfLiteDelegate* delegate) { + platforms::darwinn::tflite::FreeEdgeTpuDelegateForCustomOp(delegate); +} + +void edgetpu_verbosity(int verbosity) { + EdgeTpuManager::GetSingleton()->SetVerbosity(verbosity); +} + +const char* edgetpu_version() { + static auto* version = + new std::string(EdgeTpuManager::GetSingleton()->Version()); + return version->c_str(); +} + +} // extern "C" diff --git a/tflite/edgetpu_context_direct.cc b/tflite/edgetpu_context_direct.cc new file mode 100644 index 0000000..f6dca05 --- /dev/null +++ b/tflite/edgetpu_context_direct.cc @@ -0,0 +1,434 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
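+
+// Implements EdgeTpuDriverWrapper and EdgeTpuContextDirect. The option
+// parsers in this file recognize the following device options:
+//   "Performance":              "Low" | "Medium" | "High" | "Max"
+//   "Usb.AlwaysDfu":            "True" | "False"
+//   "Usb.MaxBulkInQueueLength": "0" .. "256"
+//
+// Illustrative sketch, with `record` taken from EnumerateEdgeTpu():
+//   auto context = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(
+//       edgetpu::DeviceType::kApexUsb, record.path,
+//       {{"Performance", "High"}, {"Usb.MaxBulkInQueueLength", "32"}});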
+ +#include "tflite/edgetpu_context_direct.h" + +#include "api/driver_factory.h" +#include "api/driver_options_generated.h" +#include "port/logging.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "tflite/custom_op_user_data_direct.h" +#include "tflite/edgetpu_manager_direct.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +namespace { + +using edgetpu::EdgeTpuContext; + +// Set default throttled usb performance based on THROTTLE_EDGE_TPU +// THROTTLE_EDGE_TPU = undefined/0: Max; 1: High; 2: Med; 3: Low; others: High +api::PerformanceExpectation DefaultThrottledUsbPerformance( + edgetpu::DeviceType device_type) { + api::PerformanceExpectation performance = api::PerformanceExpectation_Max; +#if defined(THROTTLE_EDGE_TPU) && THROTTLE_EDGE_TPU != 0 + if (device_type == edgetpu::DeviceType::kApexUsb) { + performance = api::PerformanceExpectation_High; +#if THROTTLE_EDGE_TPU == 2 + performance = api::PerformanceExpectation_Medium; +#elif THROTTLE_EDGE_TPU == 3 + performance = api::PerformanceExpectation_Low; +#endif + } +#endif + return performance; +} + +// Sets performance exectation in a driver option builder. +util::Status ParsePerformanceExpectationWithDefaultMax( + edgetpu::DeviceType device_type, + const std::unordered_map& options, + api::DriverOptionsBuilder* driver_option_builder) { + const auto& it = options.find("Performance"); + api::PerformanceExpectation performance = api::PerformanceExpectation_Max; + if (it == options.end()) { + performance = DefaultThrottledUsbPerformance(device_type); + if (performance != api::PerformanceExpectation_Max) { + VLOG(2) << "Performance expectation: " + << api::EnumNamePerformanceExpectation(performance) + << " when USB connected EdgeTpu is throttled"; + } else { + VLOG(2) << "Performance expectation: Max (default)"; + } + } else if (it->second == "Low") { + VLOG(2) << "Performance expectation: Low"; + performance = api::PerformanceExpectation_Low; + } else if (it->second == "Medium") { + VLOG(2) << "Performance expectation: Medium"; + performance = api::PerformanceExpectation_Medium; + } else if (it->second == "High") { + VLOG(2) << "Performance expectation: High"; + performance = api::PerformanceExpectation_High; + } else if (it->second == "Max") { + performance = DefaultThrottledUsbPerformance(device_type); + if (performance != api::PerformanceExpectation_Max) { + VLOG(2) << "Performance expectation level Max is not supported when " + "USB connected EdgeTpu is throttled. Drop to " + << api::EnumNamePerformanceExpectation(performance) << "."; + } else { + VLOG(2) << "Performance expectation: Max"; + } + } else { + return util::InvalidArgumentError("Invalid performance setting."); + } + + driver_option_builder->add_performance_expectation(performance); + + return util::OkStatus(); +} + +// Sets USB options in a driver USB option builder. +util::Status ParseUsbOptions( + const std::unordered_map& options, + api::DriverUsbOptionsBuilder* usb_option_builder) { + // Retrieve USB always DFU settings. + // Setting this to True would force the driver to perform DFU at driver open. 
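+  // (DFU: Device Firmware Update, i.e. the firmware image is re-downloaded to
+  // the device every time the driver is opened.)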
+ { + bool always_dfu = false; + const auto& it = options.find("Usb.AlwaysDfu"); + if (it == options.end()) { + VLOG(2) << "USB always DFU: False (default)"; + always_dfu = false; + } else if (it->second == "True") { + VLOG(2) << "USB always DFU: True"; + always_dfu = true; + } else if (it->second == "False") { + VLOG(2) << "USB always DFU: False"; + always_dfu = false; + } else { + return util::InvalidArgumentError("Invalid USB setting."); + } + + usb_option_builder->add_always_dfu(always_dfu); + } + + // Retrieve USB bulk-in queue length limit settings. + // Setting this to something large, like 32, would give better performance for + // models with many output layers. + { + int bulk_in_queue_capacity = 0; + const auto& it = options.find("Usb.MaxBulkInQueueLength"); + if (it == options.end()) { + VLOG(2) << "USB bulk-in queue capacity: default"; + } else { + // TODO: change to ABSL SimpleAtoi after it's available. + std::istringstream ss(it->second); + ss >> bulk_in_queue_capacity; + if (ss.fail() || !ss.eof()) { + return util::InvalidArgumentError( + "Converting string argument to integer failed."); + } + + if (bulk_in_queue_capacity == 0) { + VLOG(2) << "USB queued bulk-in requests disabled"; + usb_option_builder->add_enable_queued_bulk_in_requests(false); + usb_option_builder->add_has_enable_queued_bulk_in_requests(true); + } else if ((bulk_in_queue_capacity < 0) || + (bulk_in_queue_capacity > 256)) { + return util::InvalidArgumentError( + "bulk-in queue capacity must be in [0, 256]."); + } else { + VLOG(2) << "USB bulk-in queue capacity: " << bulk_in_queue_capacity; + usb_option_builder->add_bulk_in_queue_capacity(bulk_in_queue_capacity); + usb_option_builder->add_has_bulk_in_queue_capacity(true); + } + } + } + + return util::OkStatus(); +} + +} // namespace + +// EdgeTpuDriverWrapper +const char* EdgeTpuDriverWrapper::STATUS_IS_READY = "IsReady"; +const char* EdgeTpuDriverWrapper::STATUS_EXCLUSIVE_OWNERSHIP = + "ExclusiveOwnership"; + +EdgeTpuDriverWrapper::EdgeTpuDriverWrapper( + std::unique_ptr driver, + const EdgeTpuManager::DeviceEnumerationRecord& enum_record, + const EdgeTpuManager::DeviceOptions options, bool exclusive_ownership) + : use_count_(0), + is_ready_(true), + is_exclusively_owned_(exclusive_ownership), + driver_(std::move(driver)), + enum_record_(enum_record), + options_(options) { + VLOG(4) << "Opening device at " << enum_record_.path; +} + +EdgeTpuDriverWrapper::~EdgeTpuDriverWrapper() { + StdMutexLock lock(&mutex_); + + VLOG(4) << "Closing Edge TPU device at " << enum_record_.path; + + (void)driver_->Close(api::Driver::ClosingMode::kGraceful); + driver_.reset(); + is_ready_ = false; +} + +api::Driver* EdgeTpuDriverWrapper::GetDriver() const { + StdMutexLock lock(&mutex_); + + return driver_.get(); +} + +util::Status EdgeTpuDriverWrapper::InvokeExecutable(TfLiteContext* context, + TfLiteNode* node) { + StdMutexLock lock(&mutex_); + + CustomOpUserDataDirect* user_data = + reinterpret_cast(node->user_data); + + if (!driver_ || !is_ready_) { + return util::FailedPreconditionError("Edge TPU is not ready."); + } + auto executable = user_data->GetExecutable(); + const auto batches = user_data->GetBatches(); + + ASSIGN_OR_RETURN(std::shared_ptr request, + driver_->CreateRequest(executable)); + + // Attach inputs to the request. 
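+  // Batches are packed back to back: batch b of input layer i starts at
+  // offset b * ActualSizeBytes() within the tensor's raw data. Inputs backed
+  // by a dma-buf buffer handle only support a single batch (checked below).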
+ for (int i = 0; i < executable->NumInputLayers(); ++i) { + const auto* input = GetInput(context, node, i); + const auto single_input_size = executable->InputLayer(i)->ActualSizeBytes(); + if (input->buffer_handle != kTfLiteNullBufferHandle && batches > 1) { + // TODO: How to handle batches > 1? + return util::FailedPreconditionError("Too many batches for dma-buf."); + } + for (int batch = 0; batch < batches; ++batch) { + Buffer input_buffer = + input->buffer_handle == kTfLiteNullBufferHandle + ? Buffer(input->data.raw + batch * single_input_size, + single_input_size) + : Buffer(input->buffer_handle, single_input_size, false); + RETURN_IF_ERROR( + request->AddInput(executable->InputLayerName(i), input_buffer)); + } + } + + // Attach outputs to the request. + std::vector output_buffers; + output_buffers.reserve(executable->NumOutputLayers() * batches); + for (int i = 0; i < executable->NumOutputLayers(); ++i) { + for (int batch = 0; batch < batches; ++batch) { + Buffer output_buffer = + driver_->MakeBuffer(executable->OutputLayer(i)->ActualSizeBytes()); + output_buffers.push_back(output_buffer); + RETURN_IF_ERROR( + request->AddOutput(executable->OutputLayerName(i), output_buffer)); + } + } + + // Submit. + RETURN_IF_ERROR(driver_->Execute(std::move(request))); + + // Relayout tpu outputs to tflite outputs. + for (int i = 0; i < executable->NumOutputLayers(); ++i) { + auto* output = GetOutput(context, node, i); + int output_size = output->bytes / batches; + for (int batch = 0; batch < batches; ++batch) { + RETURN_IF_ERROR(ReFormatOutputs(output, batch * output_size, output_size, + executable->OutputLayer(i), + output_buffers[i].ptr())); + } + } + + return util::OkStatus(); +} + +const EdgeTpuManager::DeviceEnumerationRecord& +EdgeTpuDriverWrapper::GetDeviceEnumRecord() const { + StdMutexLock lock(&mutex_); + + return enum_record_; +} + +EdgeTpuManager::DeviceOptions EdgeTpuDriverWrapper::GetDeviceOptions() const { + StdMutexLock lock(&mutex_); + + EdgeTpuManager::DeviceOptions status(options_); + if (is_ready_) { + status.insert({STATUS_IS_READY, std::string()}); + } + if (is_exclusively_owned_) { + status.insert({STATUS_EXCLUSIVE_OWNERSHIP, std::string()}); + } + return status; +} + +util::Status EdgeTpuDriverWrapper::AddRef() { + StdMutexLock lock(&mutex_); + + // TODO: Add check for wrap around. + ++use_count_; + + return util::OkStatus(); +} + +int EdgeTpuDriverWrapper::Release() { + StdMutexLock lock(&mutex_); + + // TODO: Add check for wrap around. + --use_count_; + + return use_count_; +} + +bool EdgeTpuDriverWrapper::IsReady() const { + StdMutexLock lock(&mutex_); + + return is_ready_; +} + +bool EdgeTpuDriverWrapper::IsExclusivelyOwned() const { + StdMutexLock lock(&mutex_); + + return is_exclusively_owned_; +} + +std::unique_ptr EdgeTpuDriverWrapper::MakeOpenedDriver( + edgetpu::DeviceType device_type, const std::string& device_path, + const std::unordered_map& options) { + auto factory = api::DriverFactory::GetOrCreate(); + if (!factory) { + VLOG(1) << "Failed to create driver factory."; + return nullptr; + } + + flatbuffers::FlatBufferBuilder flatbuffer_builder; + + // TODO: Remove this empty string. + // Note that flat buffers require allocation to happen before the object's + // parent, so this string has to be allocated before the option. 
+ auto empty_public_key = flatbuffer_builder.CreateString(""); + + api::DriverUsbOptionsBuilder usb_option_builder(flatbuffer_builder); + auto parse_result = ParseUsbOptions(options, &usb_option_builder); + if (!parse_result.ok()) { + VLOG(1) << parse_result; + return nullptr; + } + auto usb_option = usb_option_builder.Finish(); + + api::DriverOptionsBuilder driver_option_builder(flatbuffer_builder); + driver_option_builder.add_public_key(empty_public_key); + driver_option_builder.add_verbosity(-1); + + parse_result = ParsePerformanceExpectationWithDefaultMax( + device_type, options, &driver_option_builder); + if (!parse_result.ok()) { + VLOG(1) << parse_result; + return nullptr; + } + + driver_option_builder.add_usb(usb_option); + api::Chip chip; + api::Device::Type type; + + int i_type = static_cast(device_type); + switch (i_type) { + case static_cast(edgetpu::DeviceType::kApexPci): + chip = api::Chip::kBeagle; + type = api::Device::Type::PCI; + break; + case static_cast(edgetpu::DeviceType::kApexUsb): + chip = api::Chip::kBeagle; + type = api::Device::Type::USB; + break; + case static_cast(DeviceTypeExtended::kApexReference): + chip = api::Chip::kBeagle; + type = api::Device::Type::REFERENCE; + break; + + default: + VLOG(1) << "Unsupported device type."; + return nullptr; + } + + flatbuffer_builder.Finish(driver_option_builder.Finish()); + + auto driver_option = api::Driver::Options( + flatbuffer_builder.GetBufferPointer(), + flatbuffer_builder.GetBufferPointer() + flatbuffer_builder.GetSize()); + + auto result = factory->CreateDriver({chip, type, device_path}, driver_option); + + if (!result.ok()) { + VLOG(1) << StringPrintf("Failed to create driver [%s] at [%s]: ", + GetDeviceTypeName(device_type), + device_path.c_str()) << result.status().ToString(); + return nullptr; + } + + auto temp_driver = std::move(result).ValueOrDie(); + auto open_status = temp_driver->Open(); + + if (!open_status.ok()) { + VLOG(1) << StringPrintf("Failed to open device [%s] at [%s]: ", + GetDeviceTypeName(device_type), + device_path.c_str()) << open_status.ToString(); + return nullptr; + } + + return temp_driver; +} + +const char* EdgeTpuDriverWrapper::GetDeviceTypeName( + edgetpu::DeviceType device_type) { + int type = static_cast(device_type); + + switch (type) { + case static_cast(edgetpu::DeviceType::kApexPci): + return "Apex (PCIe)"; + case static_cast(edgetpu::DeviceType::kApexUsb): + return "Apex (USB)"; + case static_cast(DeviceTypeExtended::kApexReference): + return "Apex (Reference)"; + default: + // Note that many internal device types do not have external names yet, so + // they cannot be named here. 
+ return "Unknown"; + } +} + +// EdgeTpuContextDirect + +EdgeTpuContextDirect::EdgeTpuContextDirect(EdgeTpuDriverWrapper* driver_wrapper) + : driver_wrapper_(driver_wrapper) { + // We're not handling notification sent to TfLiteExternalContext::Reresh + this->Refresh = nullptr; + + CHECK_OK(driver_wrapper_->AddRef()); +} + +EdgeTpuContextDirect::~EdgeTpuContextDirect() { + EdgeTpuManagerDirect::GetSingleton()->ReleaseEdgeTpuContext(driver_wrapper_); + driver_wrapper_ = nullptr; +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +namespace edgetpu { + +edgetpu::EdgeTpuContext::~EdgeTpuContext() = default; + +} // namespace edgetpu diff --git a/tflite/edgetpu_context_direct.h b/tflite/edgetpu_context_direct.h new file mode 100644 index 0000000..eb6d3a4 --- /dev/null +++ b/tflite/edgetpu_context_direct.h @@ -0,0 +1,139 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_EDGETPU_CONTEXT_DIRECT_H_ +#define DARWINN_TFLITE_EDGETPU_CONTEXT_DIRECT_H_ + +#include // NOLINT + +#include "api/driver.h" +#include "api/package_reference.h" +#include "port/errors.h" +#include "port/ptr_util.h" +#include "port/status.h" +#include "port/status_macros.h" +#include "port/statusor.h" +#include "port/thread_annotations.h" +#include "tflite/custom_op_user_data_direct.h" +#include "tflite/public/edgetpu.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +using edgetpu::DeviceType; +using edgetpu::EdgeTpuContext; +using edgetpu::EdgeTpuManager; + +// Internal-only extension to edgetpu::DeviceType +enum class DeviceTypeExtended { + kExtendedBegin = 1000, + kUnknown = kExtendedBegin + 0, + kApexReference = kExtendedBegin + 1, + kApexAny = kExtendedBegin + 2, +}; + +// Holds opened device through api::Driver interface. +class EdgeTpuDriverWrapper { + public: + // Constructs EdgeTpuDriverWrapper with an opened instance of the driver. + EdgeTpuDriverWrapper( + std::unique_ptr driver, + const EdgeTpuManager::DeviceEnumerationRecord& enum_record, + const EdgeTpuManager::DeviceOptions options, bool exclusive_ownership); + + ~EdgeTpuDriverWrapper(); + + // Returns the pointer to the driver. + api::Driver* GetDriver() const LOCKS_EXCLUDED(mutex_); + + // Synchronously executes executables for this node, with the context object + // locked. + // TODO: optimize locking, so part of pre- and post-processing can + // be done without the context object being locked. + util::Status InvokeExecutable(TfLiteContext* context, TfLiteNode* node) + LOCKS_EXCLUDED(mutex_); + + // Returns constant reference to enumeration record for this device. + const EdgeTpuManager::DeviceEnumerationRecord& GetDeviceEnumRecord() const + LOCKS_EXCLUDED(mutex_); + + // Returns a snapshot of device options and attributes. 
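+  // The snapshot consists of the options the device was opened with, plus the
+  // pseudo-entries "IsReady" and "ExclusiveOwnership" when those states hold.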
+ EdgeTpuManager::DeviceOptions GetDeviceOptions() const LOCKS_EXCLUDED(mutex_); + + // Intended to be used by #EdgeTpuContextDirect + util::Status AddRef() LOCKS_EXCLUDED(mutex_); + + // Intended to be used by #EdgeTpuManagerDirect + int Release() LOCKS_EXCLUDED(mutex_); + + // Returns true if the device is most likely ready to accept requests. + // When there are fatal errors, including unplugging of an USB device, the + // state of this device would be changed. + bool IsReady() const LOCKS_EXCLUDED(mutex_); + + // Returns true if the device is exclusively owned by an unique_ptr. + bool IsExclusivelyOwned() const LOCKS_EXCLUDED(mutex_); + + // Makes an new api::Driver and opens it, or nullptr on failure. + static std::unique_ptr MakeOpenedDriver( + DeviceType device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options); + + // Returns name in string for device types. + static const char* GetDeviceTypeName(edgetpu::DeviceType device_type); + + private: + static const char* STATUS_IS_READY; + static const char* STATUS_EXCLUSIVE_OWNERSHIP; + + // Serializes access to this device. + mutable std::mutex mutex_; + + int use_count_ GUARDED_BY(mutex_){0}; + bool is_ready_ GUARDED_BY(mutex_){false}; + bool is_exclusively_owned_ GUARDED_BY(mutex_){false}; + std::unique_ptr driver_ GUARDED_BY(mutex_); + const EdgeTpuManager::DeviceEnumerationRecord enum_record_ GUARDED_BY(mutex_); + const EdgeTpuManager::DeviceOptions options_ GUARDED_BY(mutex_); +}; + +class EdgeTpuContextDirect : public EdgeTpuContext { + public: + explicit EdgeTpuContextDirect(EdgeTpuDriverWrapper* driver_wrapper); + + ~EdgeTpuContextDirect(); + + const EdgeTpuManager::DeviceEnumerationRecord& GetDeviceEnumRecord() + const final { + return driver_wrapper_->GetDeviceEnumRecord(); + } + + EdgeTpuManager::DeviceOptions GetDeviceOptions() const final { + return driver_wrapper_->GetDeviceOptions(); + } + + bool IsReady() const final { return driver_wrapper_->IsReady(); } + + EdgeTpuDriverWrapper* GetDriverWrapper() const { return driver_wrapper_; } + + private: + EdgeTpuDriverWrapper* driver_wrapper_{nullptr}; +}; + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_EDGETPU_CONTEXT_DIRECT_H_ diff --git a/tflite/edgetpu_context_factory.cc b/tflite/edgetpu_context_factory.cc new file mode 100644 index 0000000..167d96a --- /dev/null +++ b/tflite/edgetpu_context_factory.cc @@ -0,0 +1,230 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/edgetpu_context_factory.h" + +#include "port/errors.h" +#include "port/logging.h" +#include "port/ptr_util.h" +#include "port/stringprintf.h" +#include "tflite/edgetpu_context_direct.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +const char* EdgeTpuContextFactory::GetDescriptionForDeviceTypeOptions() { + static const std::string desc = + StringPrintf("Type of Edge TPU device. 
Possible choices are %s | %s | %s", + kDeviceTypeDefault, kDeviceTypeApexUsb, kDeviceTypeApexPci); + + return desc.c_str(); +} + +const char* EdgeTpuContextFactory::GetDescriptionForDevicePathOptions() { + static const char* desc = "Path to Edge TPU device."; + return desc; +} + +const char* +EdgeTpuContextFactory::GetDescriptionForPerformanceExpectationOptions() { + static const char* desc = + "Clock rate settings affecting performance: 0: Low, 1: Medium, " + "2: High (default), 3: Max"; + return desc; +} + +util::StatusOr> +EdgeTpuContextFactory::CreateEdgeTpuContext(const std::string& device_type, + const std::string& device_path, + int performance_expectation) { + auto tpu_manager = edgetpu::EdgeTpuManager::GetSingleton(); + + if (tpu_manager == nullptr) { + // Return a null context when running on NNAPI + return std::unique_ptr(nullptr); + } + + edgetpu::DeviceType device_type_enum; + if (device_type == kDeviceTypeDefault) { + if (device_path != kDevicePathDefault) { + return util::InvalidArgumentError( + StringPrintf("device_path must be %s when device_type is %s", + kDevicePathDefault, kDeviceTypeDefault)); + } + + if (performance_expectation != kPerformanceExpectationDefault) { + return util::InvalidArgumentError(StringPrintf( + "performance_expectation has no effect when device_type is %s", + kDeviceTypeDefault)); + } + + auto result = tpu_manager->NewEdgeTpuContext(); + if (!result) { + return util::NotFoundError("Failed opening default Edge TPU."); + } + return std::move(result); + } else if (device_type == kDeviceTypeApexUsb) { + device_type_enum = edgetpu::DeviceType::kApexUsb; + } else if (device_type == kDeviceTypeApexPci) { + device_type_enum = edgetpu::DeviceType::kApexPci; + } else if (device_type == kDeviceTypeApexReference) { + device_type_enum = + static_cast(DeviceTypeExtended::kApexReference); + } else { + return util::InvalidArgumentError("Unrecognized device type."); + } + + std::string performance_expectation_str; + switch (performance_expectation) { + case 0: + performance_expectation_str = "Low"; + break; + case 1: + performance_expectation_str = "Medium"; + break; + case 2: + performance_expectation_str = "High"; + break; + case 3: + performance_expectation_str = "Max"; + break; + default: + LOG(FATAL) << "Unrecognized performance expectation."; + } + + auto result = tpu_manager->NewEdgeTpuContext( + device_type_enum, device_path, + {{"Performance", performance_expectation_str}}); + if (!result) { + return util::NotFoundError("Failed opening specified Edge TPU."); + } + + return std::move(result); +} + +util::StatusOr> +EdgeTpuContextFactory::OpenEdgeTpuContext(const std::string& device_type, + const std::string& device_path, + int performance_expectation) { + auto tpu_manager = edgetpu::EdgeTpuManager::GetSingleton(); + + if (tpu_manager == nullptr) { + // Return a null context when running on NNAPI + return std::shared_ptr(nullptr); + } + + edgetpu::DeviceType device_type_enum; + if (device_type == kDeviceTypeDefault) { + if (device_path != kDevicePathDefault) { + return util::InvalidArgumentError( + StringPrintf("device_path must be %s when device_type is %s", + kDevicePathDefault, kDeviceTypeDefault)); + } + + if (performance_expectation != kPerformanceExpectationDefault) { + return util::InvalidArgumentError(StringPrintf( + "performance_expectation has no effect when device_type is %s", + kDeviceTypeDefault)); + } + + auto result = tpu_manager->OpenDevice(); + if (!result) { + return util::NotFoundError("Failed opening default Edge TPU."); + } + return 
std::move(result); + } else if (device_type == kDeviceTypeApexUsb) { + device_type_enum = edgetpu::DeviceType::kApexUsb; + } else if (device_type == kDeviceTypeApexPci) { + device_type_enum = edgetpu::DeviceType::kApexPci; + } else if (device_type == kDeviceTypeApexReference) { + device_type_enum = + static_cast(DeviceTypeExtended::kApexReference); + } else { + return util::InvalidArgumentError("Unrecognized device type."); + } + + std::string performance_expectation_str; + switch (performance_expectation) { + case 0: + performance_expectation_str = "Low"; + break; + case 1: + performance_expectation_str = "Medium"; + break; + case 2: + performance_expectation_str = "High"; + break; + case 3: + performance_expectation_str = "Max"; + break; + default: + LOG(FATAL) << "Unrecognized performance expectation."; + } + + auto result = + tpu_manager->OpenDevice(device_type_enum, device_path, + {{"Performance", performance_expectation_str}}); + if (!result) { + return util::NotFoundError("Failed opening specified Edge TPU."); + } + + return std::move(result); +} + +util::StatusOr> +EdgeTpuContextFactory::EnumerateEdgeTpu(const std::string& device_type) { + auto tpu_manager = edgetpu::EdgeTpuManager::GetSingleton(); + + if (tpu_manager == nullptr) { + return util::FailedPreconditionError("Cannot enumerate NNAPI devices."); + } + + auto devices = tpu_manager->EnumerateEdgeTpu(); + + if (device_type != kDeviceTypeDefault) { + edgetpu::DeviceType device_type_enum; + + if (device_type == kDeviceTypeApexUsb) { + device_type_enum = edgetpu::DeviceType::kApexUsb; + } else if (device_type == kDeviceTypeApexPci) { + device_type_enum = edgetpu::DeviceType::kApexPci; + } else if (device_type == kDeviceTypeApexReference) { + device_type_enum = + static_cast(DeviceTypeExtended::kApexReference); + } else { + return util::InvalidArgumentError("Unrecognized device type."); + } + + // Filter out all devices not of the specified type. + for (auto it = devices.begin(); it != devices.end();) { + if (it->type != device_type_enum) { + it = devices.erase(it); + } else { + ++it; + } + } + } + + if (devices.empty()) { + return util::NotFoundError( + "Failed finding any Edge TPU of specified type."); + } + + return devices; +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/edgetpu_context_factory.h b/tflite/edgetpu_context_factory.h new file mode 100644 index 0000000..c167fa6 --- /dev/null +++ b/tflite/edgetpu_context_factory.h @@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_EDGE_TPU_CONTEXT_FACTORY_H_ +#define DARWINN_TFLITE_EDGE_TPU_CONTEXT_FACTORY_H_ + +#include +#include + +#include "port/statusor.h" +#include "tflite/public/edgetpu.h" +#include "tensorflow/lite/context.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// This class creates EdgeTPU context from various configurations. 
+// This class wraps edgetpu::EdgeTpuManager::NewEdgeTpuContext and makes it +// simpler to use with command line options. +class EdgeTpuContextFactory { + public: + static constexpr const char* kDeviceTypeDefault = "default"; + static constexpr const char* kDeviceTypeApexUsb = "apex_usb"; + static constexpr const char* kDeviceTypeApexPci = "apex_pci"; + + // Internal-only option, which doesn't show up in + // GetDescriptionForDeviceTypeOptions + static constexpr const char* kDeviceTypeApexReference = "apex_ref"; + + static constexpr const char* kDevicePathDefault = "default"; + + static constexpr int kPerformanceExpectationDefault = 3; + + // Returns description for device type specifier. + // Note that the return string is static to this class. + static const char* GetDescriptionForDeviceTypeOptions(); + + // Returns description for device path specifier. + // Note that the return string is static to this class. + static const char* GetDescriptionForDevicePathOptions(); + + // Returns description for performance specifier. + // Note that the return string is static to this class. + static const char* GetDescriptionForPerformanceExpectationOptions(); + + // Creates an EdgeTpu context holder on success. + static util::StatusOr> + CreateEdgeTpuContext(const std::string& device_type, + const std::string& device_path, + int performance_expectation); + + // Creates an EdgeTpu context holder on success, intended to be shared. + static util::StatusOr> + OpenEdgeTpuContext(const std::string& device_type, + const std::string& device_path, + int performance_expectation); + + // Enumerates Edge TPU devices of the specified type. + static util::StatusOr< + std::vector> + EnumerateEdgeTpu(const std::string& device_type); +}; + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_EDGE_TPU_CONTEXT_FACTORY_H_ diff --git a/tflite/edgetpu_delegate_for_custom_op.cc b/tflite/edgetpu_delegate_for_custom_op.cc new file mode 100644 index 0000000..c4083a7 --- /dev/null +++ b/tflite/edgetpu_delegate_for_custom_op.cc @@ -0,0 +1,118 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tflite/edgetpu_delegate_for_custom_op.h" + +#include +#include + +#include "port/logging.h" +#include "tflite/public/edgetpu.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/util.h" + +using tflite::ConvertVectorToTfLiteIntArray; +using tflite::TfLiteIntArrayView; + +namespace platforms { +namespace darwinn { +namespace tflite { +namespace { + +constexpr char kDelegateName[] = "EdgeTpuDelegateForCustomOp"; +constexpr int kDelegateVersion = 1; + +void* DelegateInit(TfLiteContext* context, const char* buffer, size_t length) { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + CHECK(params); + + TfLiteIntArray* nodes = params->nodes_to_replace; + CHECK_EQ(nodes->size, 1); + const int node_index = nodes->data[0]; + + TfLiteNode* node; + TfLiteRegistration* registration; + CHECK(context->GetNodeAndRegistration(context, node_index, &node, + ®istration) == kTfLiteOk); + + return edgetpu::RegisterCustomOp()->init( + context, static_cast(node->custom_initial_data), + node->custom_initial_data_size); +} + +TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteDelegate* delegate) { + context->SetExternalContext( + context, kTfLiteEdgeTpuContext, + static_cast(delegate->data_)); + + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + std::vector edgetpu_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (registration->custom_name && + std::strcmp(registration->custom_name, edgetpu::kCustomOp) == 0) { + edgetpu_nodes.push_back(node_index); + } + } + + TfLiteRegistration registration = *edgetpu::RegisterCustomOp(); + registration.init = DelegateInit; + registration.custom_name = kDelegateName; + registration.version = kDelegateVersion; + + for (int node_index : edgetpu_nodes) { + TfLiteIntArray* nodes = ConvertVectorToTfLiteIntArray({node_index}); + context->ReplaceNodeSubsetsWithDelegateKernels( + context, registration, nodes, delegate); + TfLiteIntArrayFree(nodes); + } + + return kTfLiteOk; +} + +class EdgeTpuDelegateForCustomOp : public TfLiteDelegate { + public: + EdgeTpuDelegateForCustomOp(std::shared_ptr context) + : TfLiteDelegate(TfLiteDelegateCreate()), context_(context) { + this->data_ = context.get(); + this->Prepare = PrepareImpl; + this->flags = kTfLiteDelegateFlagsAllowDynamicTensors; + } + + private: + std::shared_ptr context_; +}; + +} // namespace + +TfLiteDelegate* CreateEdgeTpuDelegateForCustomOp( + std::shared_ptr context) { + return context ? new EdgeTpuDelegateForCustomOp(context) : nullptr; +} + +void FreeEdgeTpuDelegateForCustomOp(TfLiteDelegate* delegate) { + delete static_cast(delegate); +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/edgetpu_delegate_for_custom_op.h b/tflite/edgetpu_delegate_for_custom_op.h new file mode 100644 index 0000000..bea6e4e --- /dev/null +++ b/tflite/edgetpu_delegate_for_custom_op.h @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_EDGETPU_DELEGATE_FOR_CUSTOM_OP_H_ +#define DARWINN_TFLITE_EDGETPU_DELEGATE_FOR_CUSTOM_OP_H_ + +#include + +#include "tflite/public/edgetpu.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// Creates delegate instance which enables `tflite::Interpreter` to support +// edge TPU custom op. Returns `nullptr` if context contains `nullptr`. +TfLiteDelegate* CreateEdgeTpuDelegateForCustomOp( + std::shared_ptr context); + +// Deletes created delegate instance, `delegate` may be `nullptr`. +void FreeEdgeTpuDelegateForCustomOp(TfLiteDelegate* delegate); + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_EDGETPU_DELEGATE_FOR_CUSTOM_OP_H_ diff --git a/tflite/edgetpu_delegate_for_custom_op_tflite_plugin.cc b/tflite/edgetpu_delegate_for_custom_op_tflite_plugin.cc new file mode 100644 index 0000000..725536b --- /dev/null +++ b/tflite/edgetpu_delegate_for_custom_op_tflite_plugin.cc @@ -0,0 +1,117 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "tflite/edgetpu_delegate_for_custom_op.h" +#include "tflite/public/edgetpu.h" +#include "tensorflow/lite/builtin_op_data.h" + +namespace { + +using edgetpu::DeviceType; +using edgetpu::EdgeTpuContext; +using edgetpu::EdgeTpuManager; +using DeviceOptions = EdgeTpuManager::DeviceOptions; + +typedef void (*ErrorHandler)(const char*); + +constexpr char kUsb[] = "usb"; +constexpr char kPci[] = "pci"; +constexpr char kOptionDevice[] = "device"; +constexpr int kAnyDevice = -1; + +bool MatchDevice(const std::string& s, const std::string& type, int* index) { + const auto prefix(type + ":"); + if (!absl::StartsWith(s, prefix)) return false; + if (!absl::SimpleAtoi(s.substr(prefix.size()), index)) return false; + if (*index < 0) return false; + return true; +} + +std::shared_ptr GetEdgeTpuContext( + DeviceType type, int index, const EdgeTpuManager::DeviceOptions& options) { + auto* manager = EdgeTpuManager::GetSingleton(); + if (index < 0) { + return manager->OpenDevice(type); + } else { + int i = 0; + for (auto& record : manager->EnumerateEdgeTpu()) + if (record.type == type && i++ == index) + return manager->OpenDevice(record.type, record.path); + return nullptr; + } +} + +std::shared_ptr GetEdgeTpuContext( + const DeviceOptions& options) { + auto it = options.find(kOptionDevice); + if (it == options.end()) { + return EdgeTpuManager::GetSingleton()->OpenDevice(); + } else { + const auto& device = it->second; + if (device == kUsb) { + return GetEdgeTpuContext(DeviceType::kApexUsb, kAnyDevice, options); + } else if (device == kPci) { + return GetEdgeTpuContext(DeviceType::kApexPci, kAnyDevice, options); + } else { + int index; + if (MatchDevice(device, kUsb, &index)) { + return GetEdgeTpuContext(DeviceType::kApexUsb, index, options); + } else if (MatchDevice(device, kPci, &index)) { + return GetEdgeTpuContext(DeviceType::kApexPci, index, options); + } else { + return nullptr; + } + } + } +} + +} // namespace + +extern "C" { + +// Recognized input options: +// "device": ["usb", "usb:", "pci", "pci:"] +// +// "usb" or "pci" define any available USB/PCI TPU device. +// "usb:" or "pci:" define specific USB/PCI TPU device +// according to the enumeration order from +// `edgetpu::EdgeTpuManager::EnumerateEdgeTpu` call. +// +// All options are forwarded to `edgetpu::EdgeTpuManager::OpenDevice` +// call when "device" has a form of "usb:" or "pci:", i.e. the +// following are supported as well: +// "Performance": ["Low", "Medium", "High", "Max"] (Default is "Max") +// "Usb.AlwaysDfu": ["True", "False"] (Default is "False") +// "Usb.MaxBulkInQueueLength": ["0",.., "255"] (Default is "32") +// +// Any availabe TPU device is used if "device" option is not specified. 
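+//
+// Illustrative call, error handling omitted:
+//   const char* keys[] = {"device", "Performance"};
+//   const char* values[] = {"usb:0", "High"};
+//   TfLiteDelegate* delegate = tflite_plugin_create_delegate(
+//       const_cast<char**>(keys), const_cast<char**>(values),
+//       /*num_options=*/2, /*error_handler=*/nullptr);
+//   ...
+//   tflite_plugin_destroy_delegate(delegate);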
+EDGETPU_EXPORT TfLiteDelegate* tflite_plugin_create_delegate( + char** options_keys, char** options_values, size_t num_options, + ErrorHandler error_handler) { + DeviceOptions options; + for (size_t i = 0; i < num_options; ++i) + options[options_keys[i]] = options_values[i]; + + auto context = GetEdgeTpuContext(options); + if (!context) return nullptr; + return platforms::darwinn::tflite::CreateEdgeTpuDelegateForCustomOp(context); +} + +EDGETPU_EXPORT void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate) { + platforms::darwinn::tflite::FreeEdgeTpuDelegateForCustomOp(delegate); +} + +} // extern "C" diff --git a/tflite/edgetpu_manager_direct.cc b/tflite/edgetpu_manager_direct.cc new file mode 100644 index 0000000..f1e7a09 --- /dev/null +++ b/tflite/edgetpu_manager_direct.cc @@ -0,0 +1,526 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tflite/edgetpu_manager_direct.h" + +#include + +#include "absl/strings/str_format.h" +#include "api/driver_factory.h" +#include "api/driver_options_generated.h" +#include "api/runtime_version.h" +#include "port/builddata.h" +#include "port/defs.h" +#include "port/logging.h" +#include "port/std_mutex_lock.h" +#include "port/stringprintf.h" +#include "tflite/edgetpu_context_direct.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +using edgetpu::EdgeTpuContext; + +EdgeTpuManagerDirect* EdgeTpuManagerDirect::GetSingleton() { + // Static objects with non-trivial destructors shouldn't be deleted, + // according to coding style requirement. + static auto* const impl = + new platforms::darwinn::tflite::EdgeTpuManagerDirect(); + return impl; +} + +std::unique_ptr EdgeTpuManagerDirect::NewEdgeTpuContext() { + StdMutexLock lock(&mutex_); + + return NewEdgeTpuContextInternal(DeviceTypeExtended::kApexAny, std::string(), + EdgeTpuManager::DeviceOptions()); +} + +std::unique_ptr +EdgeTpuManagerDirect::NewEdgeTpuContext(edgetpu::DeviceType device_type) { + StdMutexLock lock(&mutex_); + + return NewEdgeTpuContextInternal(static_cast(device_type), + std::string(), + EdgeTpuManager::DeviceOptions()); +} + +std::unique_ptr +EdgeTpuManagerDirect::NewEdgeTpuContext(edgetpu::DeviceType device_type, + const std::string& device_path) { + StdMutexLock lock(&mutex_); + +#if defined(THROTTLE_EDGE_TPU) && THROTTLE_EDGE_TPU != 0 + // In some cases, we need to throttle edgetpu, see b/119426047 for more + // context. Throttling only applies when EdgeTpu is connected through USB. 
+ // THROTTLE_EDGE_TPU = undefined/0: Max; 1: High; 2: Med; 3: Low; others: High + if (device_type == edgetpu::DeviceType::kApexUsb) { + VLOG(2) << "EdgeTpu is throttled."; + std::string performance_str = "High"; +#if THROTTLE_EDGE_TPU == 2 + performance_str = "Medium"; +#elif THROTTLE_EDGE_TPU == 3 + performance_str = "Low"; +#endif + return NewEdgeTpuContextInternal( + static_cast(device_type), device_path, + {{"Performance", performance_str}}); + } +#endif + + return NewEdgeTpuContextInternal(static_cast(device_type), + device_path, + EdgeTpuManager::DeviceOptions()); +} + +std::unique_ptr +EdgeTpuManagerDirect::NewEdgeTpuContext( + edgetpu::DeviceType device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) { + StdMutexLock lock(&mutex_); + + return NewEdgeTpuContextInternal(static_cast(device_type), + device_path, options); +} + +std::vector +EdgeTpuManagerDirect::EnumerateEdgeTpu() const { + StdMutexLock lock(&mutex_); + + return EnumerateEdgeTpuInternal(); +} + +std::shared_ptr EdgeTpuManagerDirect::OpenDevice() { + StdMutexLock lock(&mutex_); + + return OpenDeviceInternal(DeviceTypeExtended::kApexAny, std::string(), + EdgeTpuManager::DeviceOptions()); +} + +std::shared_ptr EdgeTpuManagerDirect::OpenDevice( + edgetpu::DeviceType device_type) { + StdMutexLock lock(&mutex_); + + return OpenDeviceInternal(static_cast(device_type), + std::string(), EdgeTpuManager::DeviceOptions()); +} + +std::shared_ptr EdgeTpuManagerDirect::OpenDevice( + edgetpu::DeviceType device_type, const std::string& device_path) { + StdMutexLock lock(&mutex_); + +#if defined(THROTTLE_EDGE_TPU) && THROTTLE_EDGE_TPU != 0 + // In some cases, we need to throttle edgetpu, see b/119426047 for more + // context. Throttling only applies when EdgeTpu is connected through USB. + // THROTTLE_EDGE_TPU = undefined/0: Max; 1: High; 2: Med; 3: Low; others: High + if (device_type == edgetpu::DeviceType::kApexUsb) { + VLOG(2) << "EdgeTpu is throttled."; + std::string performance_str = "High"; +#if THROTTLE_EDGE_TPU == 2 + performance_str = "Medium"; +#elif THROTTLE_EDGE_TPU == 3 + performance_str = "Low"; +#endif + return OpenDeviceInternal(static_cast(device_type), + device_path, + {{"Performance", performance_str}}); + } +#endif + + return OpenDeviceInternal(static_cast(device_type), + device_path, EdgeTpuManager::DeviceOptions()); +} + +std::shared_ptr EdgeTpuManagerDirect::OpenDevice( + edgetpu::DeviceType device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) { + StdMutexLock lock(&mutex_); + + return OpenDeviceInternal(static_cast(device_type), + device_path, options); +} + +std::vector> +EdgeTpuManagerDirect::GetOpenedDevices() const { + StdMutexLock lock(&mutex_); + + std::vector> results; + + for (auto& drvier_wrapper : opened_devices_) { + if (drvier_wrapper->IsExclusivelyOwned()) { + // Skips devices that are not sharable + continue; + } + + auto shared_context = + std::make_shared(drvier_wrapper.get()); + + // Note that we only keeps the weak pointer, as usually it is + results.push_back(shared_context); + } + + return results; +} + +TfLiteStatus EdgeTpuManagerDirect::SetVerbosity(int verbosity) { + StdMutexLock lock(&mutex_); + + if (verbosity < 0 || verbosity > 10) { + return kTfLiteError; + } + + // Update verbosity level. + // TODO: Verbosity level should be of per driver instance. +#if !DARWINN_PORT_USE_GOOGLE3 + ::platforms::darwinn::internal::SetLoggingLevel(verbosity); + return kTfLiteOk; +#else + // Assume FLAGS_v is defined. 
+ absl::SetFlag(&FLAGS_v, verbosity); + return kTfLiteOk; +#endif +} + +std::string EdgeTpuManagerDirect::Version() const { + StdMutexLock lock(&mutex_); + + absl::string_view build_label = BuildData::BuildLabel(); + // Note that runtime version reported here is correct only if all driver + // providers are built at the same time with this compile unit. + if (build_label.empty()) { + return StringPrintf("BuildLabel(N/A), RuntimeVersion(%d)", + api::RuntimeVersion::kCurrent); + } else { + return absl::StrFormat("BuildLabel(%s), RuntimeVersion(%d)", build_label, + api::RuntimeVersion::kCurrent); + } +} + +void EdgeTpuManagerDirect::ReleaseEdgeTpuContext( + EdgeTpuDriverWrapper* driver_wrapper) { + StdMutexLock lock(&mutex_); + + for (auto it = opened_devices_.begin(); it != opened_devices_.end(); ++it) { + if (it->get() == driver_wrapper) { + // Perform weak_ptr::lock here and create a new instance of shared_ptr to + // the context object. + auto use_count = (*it)->Release(); + if (use_count > 0) { + // This could happen when the last shared pointer to the context object + // is + VLOG(1) << "Edge TPU device at " << (*it)->GetDeviceEnumRecord().path + << " is still in use."; + } else { + VLOG(4) << "Releasing Edge TPU device at " + << (*it)->GetDeviceEnumRecord().path; + opened_devices_.erase(it); + } + return; + } + } + + LOG(FATAL) << "Could not find specified Edge TPU context to close."; +} + +std::vector +EdgeTpuManagerDirect::EnumerateEdgeTpuInternal() const { + std::vector result; + + auto factory = api::DriverFactory::GetOrCreate(); + if (!factory) { + VLOG(1) << "Failed to create driver factory."; + return result; + } + + auto devices = factory->Enumerate(); + + for (const auto& device : devices) { + edgetpu::DeviceType device_type = + static_cast(DeviceTypeExtended::kUnknown); + if (device.chip == api::Chip::kBeagle) { + switch (device.type) { + case api::Device::Type::PCI: + device_type = edgetpu::DeviceType::kApexPci; + break; + case api::Device::Type::USB: + device_type = edgetpu::DeviceType::kApexUsb; + break; + case api::Device::Type::REFERENCE: + device_type = static_cast( + DeviceTypeExtended::kApexReference); + break; + default: + VLOG(7) << "Skipping unrecognized device type: " + << static_cast(device.type); + continue; + } + } else { + VLOG(7) << "Skipping unrecognized Edge TPU type: " + << static_cast(device.chip); + continue; + } + + result.push_back(edgetpu::EdgeTpuManager::DeviceEnumerationRecord{ + device_type, device.path}); + } + return result; +} + +std::string EdgeTpuManagerDirect::FindPathToFirstUnopenedDevice( + const std::vector& candidates, + edgetpu::DeviceType request_device_type) { + for (auto& device : candidates) { + if (request_device_type != device.type) { + // Skips devices of different type. + continue; + } + + bool is_opened = false; + + // So this candidate is of the same type as requested. + // Check if it's already opened. + for (auto& drvier_wrapper : opened_devices_) { + const auto& enum_record = drvier_wrapper->GetDeviceEnumRecord(); + + if ((device.type == enum_record.type) && + (device.path == enum_record.path)) { + // A device of the same type and path is already registereed as open. + is_opened = true; + break; + } + } + + if (!is_opened) { + // We just found a device of the requested type and is not opened yet. + return device.path; + } + } + + // We couldn't find any suitable device. 
+ return std::string(); +} + +std::shared_ptr EdgeTpuManagerDirect::TryMatchDriverWrapper( + const std::vector& extended_device_types, + const std::string& extended_device_path) { + // Iterates through all requested device types. + for (auto request_device_type : extended_device_types) { + for (auto& drvier_wrapper : opened_devices_) { + const auto& enum_record = drvier_wrapper->GetDeviceEnumRecord(); + if (request_device_type != enum_record.type) { + // Skips devices of different type. + continue; + } + + // Only check the path if it's not empty. + if (!extended_device_path.empty()) { + if (extended_device_path != enum_record.path) { + // Skips devices of different path. + continue; + } + } + + if (drvier_wrapper->IsExclusivelyOwned()) { + // Skips devices that are not sharable + continue; + } + + // Now we finally find a device to be shared. + // Note we don't check for compatibility in options. + return std::make_shared(drvier_wrapper.get()); + } + } + + // Returns null pointer in case we cannot find a matching device. + return std::shared_ptr(); +} + +std::unique_ptr EdgeTpuManagerDirect::MakeDriverWrapper( + edgetpu::DeviceType request_device_type, + const std::string& extended_device_path, + const EdgeTpuManager::DeviceOptions& options, bool exclusive_ownership) { + auto driver = EdgeTpuDriverWrapper::MakeOpenedDriver( + request_device_type, extended_device_path, options); + + if (driver) { + EdgeTpuManager::DeviceEnumerationRecord enum_record; + enum_record.path = extended_device_path; + enum_record.type = request_device_type; + return gtl::MakeUnique(std::move(driver), enum_record, + options, exclusive_ownership); + } + + // In case we cannot create a new driver with its wrapper, return a null + // pointer. + return std::unique_ptr(); +} + +std::unique_ptr EdgeTpuManagerDirect::NewEdgeTpuContextInternal( + DeviceTypeExtended device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) { + bool allow_any_path = false; + auto extended_device_types = ExtendRequestDeviceType(device_type); + auto extended_device_path = device_path; + if (extended_device_path.empty() || + (extended_device_path == api::DriverFactory::kDefaultDevicePath)) { + allow_any_path = true; + } + + // Get all connected devices. Note this enumeration results do not consider if + // the devices have been opened or not. + auto candidates = EnumerateEdgeTpuInternal(); + + // Iterates through all requested device types. + for (auto request_device_type : extended_device_types) { + if (allow_any_path) { + // Overwrites extended_device_path here, as driver factory + // doesn't actually report back the exact path of the opened device. + extended_device_path = + FindPathToFirstUnopenedDevice(candidates, request_device_type); + + if (extended_device_path.empty()) { + // There is no un-opened device of this particular type. + // It's okay to leave extended_device_path as empty, for we already have + // allow_any_path set to true. + VLOG(5) << "No device of type " + << EdgeTpuDriverWrapper::GetDeviceTypeName( + static_cast(request_device_type)) + << " is available."; + continue; + } + } + + // We have a path. Now try to open and creates a wrapper for it. + std::unique_ptr driver_wrapper = + MakeDriverWrapper(request_device_type, extended_device_path, options, + /*exclusive_ownership=*/true); + + if (!driver_wrapper) { + VLOG(1) + << "Failed creating new Edge TPU context for exclusive ownership."; + + // Returns a null pointer on error. 
+      return std::unique_ptr<EdgeTpuContext>();
+    }
+
+    EdgeTpuDriverWrapper* wrapper_pointer = driver_wrapper.get();
+
+    // Commits the driver wrapper into opened devices.
+    opened_devices_.push_back(std::move(driver_wrapper));
+
+    return gtl::MakeUnique<EdgeTpuContextDirect>(wrapper_pointer);
+  }
+
+  VLOG(1) << "Failed allocating Edge TPU device for exclusive ownership.";
+
+  return std::unique_ptr<EdgeTpuContext>();
+}
+
+std::shared_ptr<EdgeTpuContext> EdgeTpuManagerDirect::OpenDeviceInternal(
+    DeviceTypeExtended device_type, const std::string& device_path,
+    const EdgeTpuManager::DeviceOptions& options) {
+  bool allow_any_path = false;
+  auto extended_device_types = ExtendRequestDeviceType(device_type);
+  auto extended_device_path = device_path;
+  if (extended_device_path.empty() ||
+      (extended_device_path == api::DriverFactory::kDefaultDevicePath)) {
+    allow_any_path = true;
+    extended_device_path.clear();
+  }
+
+  // Tries to find a match among all the requested device types.
+  // The returned shared_ptr would prevent the device from being closed.
+  auto context =
+      TryMatchDriverWrapper(extended_device_types, extended_device_path);
+
+  if (context) {
+    // Returns the already-opened device.
+    return context;
+  }
+
+  VLOG(5) << "No matching device is already opened for shared ownership.";
+
+  // Gets all connected devices. Note that these enumeration results do not
+  // consider whether the devices have already been opened.
+  auto candidates = EnumerateEdgeTpuInternal();
+
+  // Iterates through all requested device types.
+  for (auto request_device_type : extended_device_types) {
+    if (allow_any_path) {
+      // We have to overwrite extended_device_path here, as the driver factory
+      // doesn't actually report back the exact path of the opened device.
+      extended_device_path =
+          FindPathToFirstUnopenedDevice(candidates, request_device_type);
+
+      if (extended_device_path.empty()) {
+        // There is no unopened device of this particular type.
+        // It's okay to leave extended_device_path empty, as allow_any_path is
+        // already set to true.
+        VLOG(5) << "No device of type "
+                << EdgeTpuDriverWrapper::GetDeviceTypeName(
+                       static_cast(request_device_type))
+                << " is available.";
+        continue;
+      }
+    }
+
+    std::unique_ptr<EdgeTpuDriverWrapper> driver_wrapper =
+        MakeDriverWrapper(request_device_type, extended_device_path, options,
+                          /*exclusive_ownership=*/false);
+
+    if (!driver_wrapper) {
+      // Returns a null pointer on error.
+      return std::shared_ptr<EdgeTpuContext>();
+    }
+
+    EdgeTpuDriverWrapper* wrapper_pointer = driver_wrapper.get();
+
+    // Commits the driver wrapper into opened devices.
+ opened_devices_.push_back(std::move(driver_wrapper)); + + return std::make_shared(wrapper_pointer); + } + + VLOG(1) << "Failed allocating Edge TPU device for shared ownership."; + + return std::shared_ptr(); +} + +std::vector EdgeTpuManagerDirect::ExtendRequestDeviceType( + DeviceTypeExtended device_type) { + std::vector request_device_types; + + if (device_type == DeviceTypeExtended::kApexAny) { + // If the device type is Apex Any, try all supported types + // 1st priority: PCIe + request_device_types.push_back(edgetpu::DeviceType::kApexPci); + + // 2nd priority: USB + request_device_types.push_back(edgetpu::DeviceType::kApexUsb); + + // 3rd priority: Reference device + request_device_types.push_back( + static_cast(DeviceTypeExtended::kApexReference)); + } else { + request_device_types.push_back( + static_cast(device_type)); + } + + return request_device_types; +} + +} // namespace tflite +} // namespace darwinn +} // namespace platforms diff --git a/tflite/edgetpu_manager_direct.h b/tflite/edgetpu_manager_direct.h new file mode 100644 index 0000000..f120996 --- /dev/null +++ b/tflite/edgetpu_manager_direct.h @@ -0,0 +1,147 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DARWINN_TFLITE_EDGETPU_MANAGER_DIRECT_H_ +#define DARWINN_TFLITE_EDGETPU_MANAGER_DIRECT_H_ + +#include // NOLINT + +#include "port/thread_annotations.h" +#include "tflite/edgetpu_context_direct.h" +#include "tflite/public/edgetpu.h" + +namespace platforms { +namespace darwinn { +namespace tflite { + +// TPU Manager implementation for direct API. +// This class is threadsafe, as multiple TPU contexts, driven by multiple +// Interpreter threads, could access this singleton at the same time. 
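+//
+// As a rough usage sketch (illustrative only, not a contract of this header),
+// application code normally reaches this implementation through the public
+// edgetpu::EdgeTpuManager interface, e.g.:
+//
+//   auto* manager = edgetpu::EdgeTpuManager::GetSingleton();
+//   std::shared_ptr<edgetpu::EdgeTpuContext> context = manager->OpenDevice();
+//   if (context) {
+//     // Bind the context to a tflite::Interpreter before AllocateTensors().
+//     interpreter->SetExternalContext(kTfLiteEdgeTpuContext, context.get());
+//   }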
+class EdgeTpuManagerDirect : public edgetpu::EdgeTpuManager { + public: + virtual ~EdgeTpuManagerDirect() = default; + + static EdgeTpuManagerDirect* GetSingleton(); + + std::unique_ptr NewEdgeTpuContext() final + LOCKS_EXCLUDED(mutex_); + + std::unique_ptr NewEdgeTpuContext( + DeviceType device_type) final LOCKS_EXCLUDED(mutex_); + + std::unique_ptr NewEdgeTpuContext( + DeviceType device_type, const std::string& device_path) final + LOCKS_EXCLUDED(mutex_); + + std::unique_ptr NewEdgeTpuContext( + DeviceType device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) final + LOCKS_EXCLUDED(mutex_); + + std::vector EnumerateEdgeTpu() const final + LOCKS_EXCLUDED(mutex_); + + std::shared_ptr OpenDevice() final LOCKS_EXCLUDED(mutex_); + + std::shared_ptr OpenDevice(DeviceType device_type) final + LOCKS_EXCLUDED(mutex_); + + std::shared_ptr OpenDevice( + DeviceType device_type, const std::string& device_path) final + LOCKS_EXCLUDED(mutex_); + + std::shared_ptr OpenDevice(DeviceType device_type, + const std::string& device_path, + const DeviceOptions& options) final + LOCKS_EXCLUDED(mutex_); + + std::vector> GetOpenedDevices() const final + LOCKS_EXCLUDED(mutex_); + + TfLiteStatus SetVerbosity(int verbosity) final LOCKS_EXCLUDED(mutex_); + + std::string Version() const final LOCKS_EXCLUDED(mutex_); + + // Intended to be used by #~EdgeTpuContextDirect() + // When EdgeTpuContextDirect is destructed, it would release reference to the + // underlying EdgeTpuDriverWrapper with this call. + void ReleaseEdgeTpuContext(EdgeTpuDriverWrapper* driver_wrapper) + LOCKS_EXCLUDED(mutex_); + + private: + EdgeTpuManagerDirect() = default; + + // Enumerates connected TPU devices. + std::vector EnumerateEdgeTpuInternal() const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns path to the first device of requested type which is not yet opened. + // Reutnrs empty string if no such device can be found. + std::string FindPathToFirstUnopenedDevice( + const std::vector& candidates, + edgetpu::DeviceType request_device_type) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns a shared pointer to existing/opened driver wrapper of the specified + // device types and path. The path argument can be empty + // string so it matches with any device path. + std::shared_ptr TryMatchDriverWrapper( + const std::vector& extended_device_types, + const std::string& extended_device_path) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Constructs device driver instance, opens it, and then returns an unique + // pointer to the driver wrapper. + std::unique_ptr MakeDriverWrapper( + edgetpu::DeviceType request_device_type, + const std::string& extended_device_path, + const EdgeTpuManager::DeviceOptions& options, bool exclusive_ownership) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns an unique pointer to EdgeTpuContext, which is opened for exclusive + // ownership. + // + // #device_type can be kApexAny, as it's further extended by + // ExtendRequestDeviceType. #device_path can be either empty or + // api::DriverFactory::kDefaultDevicePath, both match to any device path. + std::unique_ptr NewEdgeTpuContextInternal( + DeviceTypeExtended device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns a shared pointer to EdgeTpuContext, which is opened for shared + // ownership. + // + // #device_type can be kApexAny, as it's further extended by + // ExtendRequestDeviceType. 
#device_path can be either empty or + // api::DriverFactory::kDefaultDevicePath, both match to any device path. + std::shared_ptr OpenDeviceInternal( + DeviceTypeExtended device_type, const std::string& device_path, + const EdgeTpuManager::DeviceOptions& options) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Extends wildcard device type to a list of types. + static std::vector ExtendRequestDeviceType( + DeviceTypeExtended device_type); + + // Serializes access to this singleton. + mutable std::mutex mutex_; + + std::vector> opened_devices_ + GUARDED_BY(mutex_); +}; + +} // namespace tflite +} // namespace darwinn +} // namespace platforms + +#endif // DARWINN_TFLITE_EDGETPU_MANAGER_DIRECT_H_ diff --git a/tflite/public/BUILD b/tflite/public/BUILD new file mode 100644 index 0000000..63db841 --- /dev/null +++ b/tflite/public/BUILD @@ -0,0 +1,160 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# BUILD rules for DarwiNN TfLite Custom-op public interface. + +package( + default_visibility = ["//visibility:public"], +) + +# This group is meant for users of USB TPUs who also have a dependency to +# libcap (ex: CNS file support depends on libcap). Using the default USB +# driver wouldn't work in that case as libUSB is dynamically linked and +# thus the global _cap_names in libcap ends up twice in your binary, +# causing an ODR violation. Also to comply with libUSB LGPL license, this +# version only exist for internal usage (no external release, this is why +# the user are limited to this group). +package_group( + name = "libusb_statically_linked_users", + packages = [ + "//lifescience/cad/ophthalmology/arda_camera/dagon/...", + "//vr/stargate/...", + ], +) + +licenses(["notice"]) + +# Version number used in the soname of shared libraries. +VERSION = "1" + +SHARED_LIBRARY_LINKOPTS = [ + "-Wl,-soname,libedgetpu.so." + VERSION, + "-Wl,--version-script=$(location libedgetpu.lds)", +] + +# Header for external use. +cc_library( + name = "edgetpu", + hdrs = [ + "edgetpu.h", + ], + defines = select({ + "//:windows": ["EDGETPU_COMPILE_LIBRARY"], + "//conditions:default": [], + }), + deps = [ + "@org_tensorflow//tensorflow/lite:context", + ], +) + +cc_library( + name = "edgetpu_c", + hdrs = [ + "edgetpu_c.h", + ], + defines = select({ + "//:windows": ["EDGETPU_COMPILE_LIBRARY"], + "//conditions:default": [], + }), + deps = [ + "@org_tensorflow//tensorflow/lite:context", + ], +) + +# Shared library for external use. +# Explicit variant for all(pci/usb). +cc_binary( + name = "libedgetpu_direct_all.so", + linkopts = SHARED_LIBRARY_LINKOPTS, + linkshared = 1, + linkstatic = 1, + deps = [ + "libedgetpu.lds", + "//driver/beagle:beagle_all_driver_provider", + "//tflite:custom_op_direct", + ], +) + +# Shared library for external use. +# Explicit variant for Beagle PCIe. 
+cc_binary( + name = "libedgetpu_direct_pci.so", + linkopts = SHARED_LIBRARY_LINKOPTS, + linkshared = 1, + linkstatic = 1, + deps = [ + "libedgetpu.lds", + "//driver/beagle:beagle_pci_driver_provider", + "//tflite:custom_op_direct", + ], +) + +# Shared library for linking of applications not depending on a particular driver provider. +cc_binary( + name = "libedgetpu_bare.so", + linkopts = SHARED_LIBRARY_LINKOPTS, + linkshared = 1, + linkstatic = 1, + deps = [ + "libedgetpu.lds", + "//tflite:custom_op_direct", + ], +) + +# Shared library for external use. +# Explicit variant for Beagle USB. +cc_binary( + name = "libedgetpu_direct_usb.so", + linkopts = SHARED_LIBRARY_LINKOPTS, + linkshared = 1, + linkstatic = 1, + deps = [ + "libedgetpu.lds", + "//driver/beagle:beagle_usb_driver_provider", + "//tflite:custom_op_direct", + ], +) + +cc_binary( + name = "libedgetpu_direct_usb.dylib", + linkopts = [ + "-Wl,-install_name,@rpath/libedgetpu." + VERSION + ".dylib", + ], + linkshared = 1, + linkstatic = 1, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [ + "//driver/beagle:beagle_usb_driver_provider", + "//tflite:custom_op_direct", + ], +) + +cc_binary( + name = "edgetpu_direct_usb.dll", + linkshared = 1, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [ + "//driver/beagle:beagle_usb_driver_provider", + "//tflite:custom_op_direct", + "@libusb//:shared", + ], +) diff --git a/tflite/public/edgetpu.h b/tflite/public/edgetpu.h new file mode 100644 index 0000000..2f8f18e --- /dev/null +++ b/tflite/public/edgetpu.h @@ -0,0 +1,290 @@ +/* +Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// +// This header file defines EdgeTpuManager, and EdgeTpuContext. +// EdgeTpuContext is an object associated with one or more tflite::Interpreter. +// Instances of this class should be allocated through +// EdgeTpuManager::NewEdgeTpuContext. +// More than one Interpreter instances can point to the same context. This means +// the tasks from both would be executed under the same TPU context. +// The lifetime of this context must be longer than all associated +// tflite::Interpreter instances. +// +// Typical usage with NNAPI: +// +// std::unique_ptr interpreter; +// tflite::ops::builtin::BuiltinOpResolver resolver; +// auto model = +// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str()); +// +// // Registers edge TPU custom op handler with Tflite resolver. +// resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp()); +// +// tflite::InterpreterBuilder(*model, resolver)(&interpreter); +// +// interpreter->AllocateTensors(); +// .... (Prepare input tensors) +// interpreter->Invoke(); +// .... (retrieving the result from output tensors) +// +// // Releases interpreter instance to free up resources associated with +// // this custom op. +// interpreter.reset(); +// +// Typical usage with Non-NNAPI: +// +// // Sets up the tpu_context. 
+// auto tpu_context = +// edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); +// +// std::unique_ptr interpreter; +// tflite::ops::builtin::BuiltinOpResolver resolver; +// auto model = +// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str()); +// +// // Registers edge TPU custom op handler with Tflite resolver. +// resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp()); +// +// tflite::InterpreterBuilder(*model, resolver)(&interpreter); +// +// // Binds a context with a specific interpreter. +// interpreter->SetExternalContext(kTfLiteEdgeTpuContext, +// tpu_context.get()); +// +// // Note that all edge TPU context set ups should be done before this +// // function is called. +// interpreter->AllocateTensors(); +// .... (Prepare input tensors) +// interpreter->Invoke(); +// .... (retrieving the result from output tensors) +// +// // Releases interpreter instance to free up resources associated with +// // this custom op. +// interpreter.reset(); +// +// // Closes the edge TPU. +// tpu_context.reset(); + +#ifndef TFLITE_PUBLIC_EDGETPU_H_ +#define TFLITE_PUBLIC_EDGETPU_H_ + +// If the ABI changes in a backward-incompatible way, please increment the +// version number in the BUILD file. + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/context.h" + +#if defined(_WIN32) +#ifdef EDGETPU_COMPILE_LIBRARY +#define EDGETPU_EXPORT __declspec(dllexport) +#else +#define EDGETPU_EXPORT __declspec(dllimport) +#endif // EDGETPU_COMPILE_LIBRARY +#else +#define EDGETPU_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +namespace edgetpu { + +// EdgeTPU custom op. +static const char kCustomOp[] = "edgetpu-custom-op"; + +enum class DeviceType { + kApexPci = 0, + kApexUsb = 1, +}; + +class EdgeTpuContext; + +// Singleton edge TPU manager for allocating new TPU contexts. +// Functions in this interface are thread-safe. +class EDGETPU_EXPORT EdgeTpuManager { + public: + using DeviceOptions = std::unordered_map; + struct DeviceEnumerationRecord { + DeviceType type; + std::string path; + + // Returns true if two enumeration records point to the same device. + friend bool operator==(const DeviceEnumerationRecord& lhs, + const DeviceEnumerationRecord& rhs) { + return (lhs.type == rhs.type) && (lhs.path == rhs.path); + } + + // Returns true if two enumeration records point to defferent devices. + friend bool operator!=(const DeviceEnumerationRecord& lhs, + const DeviceEnumerationRecord& rhs) { + return !(lhs == rhs); + } + }; + + // Returns pointer to the singleton object, or nullptr if not supported on + // this platform. + static EdgeTpuManager* GetSingleton(); + + // NewEdgeTpuContext family functions has been deprecated and will be removed + // in the future. Please use OpenDevice for new code. + // + // These functions return an unique_ptr to EdgeTpuContext, with + // the intention that the device will be closed, and associate resources + // released, when the unique_ptr leaves scope. + // + // These functions seek exclusive ownership of the opened devices. As they + // cannot open devices already opened by OpenDevice, and vice versa. + // Devices opened through these functions would have attribute + // "ExclusiveOwnership", which can be queried through + // #EdgeTpuContext::GetDeviceOptions(). + + // Creates a new Edge TPU context to be assigned to Tflite::Interpreter. The + // Edge TPU context is associated with the default TPU device. May be null + // if underlying device cannot be found or open. 
Caller owns the returned new + // context and should destroy the context either implicity or explicitly after + // all interpreters sharing this context are destroyed. + virtual std::unique_ptr NewEdgeTpuContext() = 0; + + // Same as above, but the created context is associated with the specified + // type. + virtual std::unique_ptr NewEdgeTpuContext( + DeviceType device_type) = 0; + + // Same as above, but the created context is associated with the specified + // type and device path. + virtual std::unique_ptr NewEdgeTpuContext( + DeviceType device_type, const std::string& device_path) = 0; + + // Same as above, but the created context is associated with the given device + // type, path and options. + // + // Available options are: + // - "Performance": ["Low", "Medium", "High", "Max"] (Default is "Max") + // - "Usb.AlwaysDfu": ["True", "False"] (Default is "False") + // - "Usb.MaxBulkInQueueLength": ["0",.., "255"] (Default is "32") + virtual std::unique_ptr NewEdgeTpuContext( + DeviceType device_type, const std::string& device_path, + const DeviceOptions& options) = 0; + + // Enumerates all connected Edge TPU devices. + virtual std::vector EnumerateEdgeTpu() const = 0; + + // OpenDevice family of functions return a shared_ptr to EdgeTpuContext, with + // the intention that the device can be shared among multiple software + // components. + // + // These functions seek shared ownership of the opened devices. As they + // cannot open devices already opened by NewEdgeTpuContext, and vice versa. + // The device would be closed after the last reference leaves scope. + + // Opens the default Edge TPU device. + // + // Multiple invocations of this function could return handle to the same + // device, but there is no guarantee. + // + // Returns a shared pointer to Edge TPU device. The shared_ptr could point to + // nullptr in case of error. + virtual std::shared_ptr OpenDevice() = 0; + + // Same as above, but the returned context is associated with the specified + // type. + virtual std::shared_ptr OpenDevice( + DeviceType device_type) = 0; + + // Same as above, but the returned context is associated with the specified + // type and device path. If path is empty, any device of the specified type + // could be returned. + virtual std::shared_ptr OpenDevice( + DeviceType device_type, const std::string& device_path) = 0; + + // Same as above, but the specified options would used to create a new context + // if no existing device is compatible with the specified type and path. + // + // If a device of compatible type and path can be found, the options could be + // ignored. It is the caller's responsibility to verify if the returned + // context is desirable, through #EdgeTpuContext::GetDeviceOptions(). + // + // Available options are: + // - "Performance": ["Low", "Medium", "High", "Max"] (Default is "Max") + // - "Usb.AlwaysDfu": ["True", "False"] (Default is "False") + // - "Usb.MaxBulkInQueueLength": ["0",.., "255"] (Default is "32") + virtual std::shared_ptr OpenDevice( + DeviceType device_type, const std::string& device_path, + const DeviceOptions& options) = 0; + + // Returns a snapshot of currently opened shareable devices. + // Exclusively owned Edge TPU devices cannot be returned here, as they're + // owned by unique pointers. + virtual std::vector> GetOpenedDevices() + const = 0; + + // Sets verbosity of operating logs related to edge TPU. + // Verbosity level can be set to [0-10], in which 10 is the most verbose. 
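+  // For example, a call along the lines of
+  //   edgetpu::EdgeTpuManager::GetSingleton()->SetVerbosity(10);
+  // would request the most detailed logging (a usage sketch, not a guarantee
+  // of any specific log output).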
+ virtual TfLiteStatus SetVerbosity(int verbosity) = 0; + + // Returns the version of EdgeTPU runtime stack. + virtual std::string Version() const = 0; + + protected: + // No deletion for this singleton instance. + virtual ~EdgeTpuManager() = default; +}; + +// External context to be assigned through +// tflite::Interpreter::SetExternalContext. +// One should get hold of either shared_ptr from EdgeTpuManager::OpenDevice, or +// unique_ptr from EdgeTpuManager::NewEdgeTpuContext, to ensure ownership, and +// avoid using this pointer directly. +// Functions in this interface are thread-safe. +class EdgeTpuContext : public TfLiteExternalContext { + public: + virtual ~EdgeTpuContext() = 0; + + // Returns a pointer to the device enumeration record for this device, + // if available. + virtual const EdgeTpuManager::DeviceEnumerationRecord& GetDeviceEnumRecord() + const = 0; + + // Returns a snapshot of the options used to open this + // device, and current state, if available. + // + // Supported attributes are: + // - "ExclusiveOwnership": present when it is under exclusive ownership + // (unique_ptr returned by NewEdgeTpuContext). + // - "IsReady": present when it is ready for further requests. + virtual EdgeTpuManager::DeviceOptions GetDeviceOptions() const = 0; + + // Returns true if the device is most likely ready to accept requests. + // When there are fatal errors, including unplugging of an USB device, the + // state of this device would be changed. + virtual bool IsReady() const = 0; +}; + +// Returns pointer to an instance of TfLiteRegistration to handle +// EdgeTPU custom ops, to be used with +// tflite::ops::builtin::BuiltinOpResolver::AddCustom +EDGETPU_EXPORT TfLiteRegistration* RegisterCustomOp(); + +// Inserts name of device type into ostream. Returns the modified ostream. +EDGETPU_EXPORT std::ostream& operator<<(std::ostream& out, + DeviceType device_type); + +} // namespace edgetpu + + +#endif // TFLITE_PUBLIC_EDGETPU_H_ diff --git a/tflite/public/edgetpu_c.h b/tflite/public/edgetpu_c.h new file mode 100644 index 0000000..54f4f05 --- /dev/null +++ b/tflite/public/edgetpu_c.h @@ -0,0 +1,116 @@ +/* +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// +// This header defines C API to provide edge TPU support for TensorFlow Lite +// framework. It is only available for non-NNAPI use cases. +// +// Typical API usage from C++ code involves serveral steps: +// +// 1. Create tflite::FlatBufferModel which may contain edge TPU custom op. +// +// auto model = +// tflite::FlatBufferModel::BuildFromFile(model_file_name.c_str()); +// +// 2. Create tflite::Interpreter. +// +// tflite::ops::builtin::BuiltinOpResolver resolver; +// std::unique_ptr interpreter; +// tflite::InterpreterBuilder(model, resolver)(&interpreter); +// +// 3. Enumerate edge TPU devices. 
+// +// size_t num_devices; +// std::unique_ptr devices( +// edgetpu_list_devices(&num_devices), &edgetpu_free_devices); +// +// assert(num_devices > 0); +// const auto& device = devices.get()[0]; +// +// 4. Modify interpreter with the delegate. +// +// auto* delegate = +// edgetpu_create_delegate(device.type, device.path, nullptr, 0); +// interpreter->ModifyGraphWithDelegate({delegate, edgetpu_free_delegate}); +// +// 5. Prepare input tensors and run inference. +// +// interpreter->AllocateTensors(); +// .... (Prepare input tensors) +// interpreter->Invoke(); +// .... (Retrieve the result from output tensors) + +#ifndef TFLITE_PUBLIC_EDGETPU_C_H_ +#define TFLITE_PUBLIC_EDGETPU_C_H_ + +#include "tensorflow/lite/context.h" + +#if defined(_WIN32) +#ifdef EDGETPU_COMPILE_LIBRARY +#define EDGETPU_EXPORT __declspec(dllexport) +#else +#define EDGETPU_EXPORT __declspec(dllimport) +#endif // EDGETPU_COMPILE_LIBRARY +#else +#define EDGETPU_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +#ifdef __cplusplus +extern "C" { +#endif + +enum edgetpu_device_type { + EDGETPU_APEX_PCI = 0, + EDGETPU_APEX_USB = 1, +}; + +struct edgetpu_device { + enum edgetpu_device_type type; + const char* path; +}; + +struct edgetpu_option { + const char* name; + const char* value; +}; + +// Returns array of connected edge TPU devices. +EDGETPU_EXPORT struct edgetpu_device* edgetpu_list_devices(size_t* num_devices); + +// Frees array returned by `edgetpu_list_devices`. +EDGETPU_EXPORT void edgetpu_free_devices(struct edgetpu_device* dev); + +// Creates a delegate which handles all edge TPU custom ops inside +// `tflite::Interpreter`. Options must be available only during the call of this +// function. +EDGETPU_EXPORT TfLiteDelegate* edgetpu_create_delegate( + enum edgetpu_device_type type, const char* name, + const struct edgetpu_option* options, size_t num_options); + +// Frees delegate returned by `edgetpu_create_delegate`. +EDGETPU_EXPORT void edgetpu_free_delegate(TfLiteDelegate* delegate); + +// Sets verbosity of operating logs related to edge TPU. +// Verbosity level can be set to [0-10], in which 10 is the most verbose. +EDGETPU_EXPORT void edgetpu_verbosity(int verbosity); + +// Returns the version of edge TPU runtime stack. +EDGETPU_EXPORT const char* edgetpu_version(); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TFLITE_PUBLIC_EDGETPU_C_H_ diff --git a/tflite/public/libedgetpu.lds b/tflite/public/libedgetpu.lds new file mode 100644 index 0000000..6a07c66 --- /dev/null +++ b/tflite/public/libedgetpu.lds @@ -0,0 +1,15 @@ +VER_1.0 { + global: + extern "C++" { + edgetpu::*; + }; + /* Functions to support edge TPU custom op from TFLite Python API, check + * `//third_party/tensorflow/lite/python/interpreter.py` for the spec. + */ + tflite_plugin_*; + /* Edge TPU C API functions. + */ + edgetpu_*; + local: + *; +};
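For reference, the numbered steps in the edgetpu_c.h comment can be assembled into one self-contained sketch. This is illustrative only: the model file name is a placeholder, error handling is omitted, and the include paths assume the standard TensorFlow Lite source layout alongside this repository.

    #include <cstddef>
    #include <memory>

    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"
    #include "tflite/public/edgetpu_c.h"

    int main() {
      // Enumerate connected Edge TPU devices and take the first one.
      size_t num_devices = 0;
      std::unique_ptr<edgetpu_device, decltype(&edgetpu_free_devices)> devices(
          edgetpu_list_devices(&num_devices), &edgetpu_free_devices);
      if (num_devices == 0) return 1;
      const edgetpu_device& device = devices.get()[0];

      // Build an interpreter for a model containing the Edge TPU custom op.
      auto model = tflite::FlatBufferModel::BuildFromFile("model_edgetpu.tflite");
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);

      // Attach the Edge TPU delegate, then allocate tensors and run.
      TfLiteDelegate* delegate =
          edgetpu_create_delegate(device.type, device.path, nullptr, 0);
      interpreter->ModifyGraphWithDelegate(delegate);
      interpreter->AllocateTensors();
      // ... fill input tensors ...
      interpreter->Invoke();
      // ... read output tensors ...

      // Destroy the interpreter before freeing the delegate it uses.
      interpreter.reset();
      edgetpu_free_delegate(delegate);
      return 0;
    }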