Sync with ray master (#33)

* [rllib] Remove dependency on TensorFlow (ray-project#4764)

* remove hard tf dep

* add test

* comment fix

* fix test

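For reference, the soft dependency boils down to an import guard along these lines (a standalone sketch; RLlib's actual helper is the try_import_tf utility, so treat this version as illustrative):

    def try_import_tf():
        # Return the tensorflow module if it is installed, else None, so
        # callers can degrade gracefully instead of failing at import time.
        try:
            import tensorflow as tf
            return tf
        except ImportError:
            return None

    tf = try_import_tf()
    if tf is None:
        print("TensorFlow not installed; TF-based policies are unavailable.")
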
* Dynamic Custom Resources - create and delete resources (ray-project#3742)

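A minimal usage sketch of the dynamic-resource API, assuming the ray.experimental.set_resource entry point and zero-capacity deletion semantics:

    import ray

    ray.init()
    # Create a custom resource with capacity 4 on the local node.
    ray.experimental.set_resource("my_resource", 4)
    # Delete it again by setting its capacity to zero (assumed semantics).
    ray.experimental.set_resource("my_resource", 0)
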
* Update tutorial link in doc (ray-project#4777)

* [rllib] Implement learn_on_batch() in torch policy graph

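The learn_on_batch() contract, as a self-contained sketch (class and batch keys are illustrative, not the torch policy graph's internals): take a batch of tensors, run one optimizer step, and return learner stats.

    import torch

    class TorchPolicySketch:
        def __init__(self):
            self.model = torch.nn.Linear(4, 2)
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

        def learn_on_batch(self, samples):
            # One gradient step on a dict of torch tensors.
            loss = torch.nn.functional.mse_loss(
                self.model(samples["obs"]), samples["targets"])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return {"loss": loss.item()}

    stats = TorchPolicySketch().learn_on_batch(
        {"obs": torch.randn(8, 4), "targets": torch.randn(8, 2)})
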
* Fix `ray stop` by killing raylet before plasma (ray-project#4778)

* Fatal check if object store dies (ray-project#4763)

* [rllib] fix clip by value issue as TF upgraded (ray-project#4697)

* fix clip_by_value issue

* fix typo

* [autoscaler] Fix submit (ray-project#4782)

* Queue tasks in the raylet in between async callbacks (ray-project#4766)

* Add a SWAP TaskQueue so that we can keep track of tasks that are temporarily dequeued

* Fix a bug where tasks that fail to be forwarded don't appear to be local, by adding them to the SWAP queue

* cleanups

* updates

* updates

* [Java][Bazel] Refine auto-generated pom files (ray-project#4780)

* Bump version to 0.7.0 (ray-project#4791)

* [JAVA] setDefaultUncaughtExceptionHandler to log uncaught exception in user thread. (ray-project#4798)

* Add WorkerUncaughtExceptionHandler

* Fix

* revert bazel and pom

* [tune] Fix CLI test (ray-project#4801)

* Fix pom file generation (ray-project#4800)

* [rllib] Support continuous action distributions in IMPALA/APPO (ray-project#4771)

* [rllib] TensorFlow 2 compatibility (ray-project#4802)

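A common shim for running 1.x-style graph code under TF 2, offered as an assumption about the approach rather than this PR's exact diff:

    import tensorflow as tf

    # Under TF 2.x, fall back to the 1.x graph-mode API surface.
    if hasattr(tf, "compat") and hasattr(tf.compat, "v1"):
        tf = tf.compat.v1
        if hasattr(tf, "disable_eager_execution"):
            tf.disable_eager_execution()
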
* Change tagline in documentation and README. (ray-project#4807)

* Update README.rst, index.rst, tutorial.rst, and _config.yml

* [tune] Support non-arg submit (ray-project#4803)

* [autoscaler] rsync cluster (ray-project#4785)

* [tune] Remove extra parsing functionality (ray-project#4804)

* Fix Java worker log dir (ray-project#4781)

* [tune] Initial track integration (ray-project#4362)

Introduces a minimally invasive utility for logging experiment results. A broad requirement for this tool is that it should integrate seamlessly with Tune execution.

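A hedged sketch of the logging flow (entry points follow the track module added here; exact signatures may differ):

    from ray.tune import track

    track.init()  # standalone use, outside a Tune trial
    for step in range(10):
        acc = 0.5 + step * 0.01  # placeholder metric for illustration
        track.log(mean_accuracy=acc, training_iteration=step)
    track.shutdown()
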
* [rllib] [RFC] Dynamic definition of loss functions and modularization support (ray-project#4795)

* dynamic graph

* wip

* clean up

* fix

* document trainer

* wip

* initialize the graph using a fake batch

* clean up dynamic init

* wip

* spelling

* use builder for ppo pol graph

* add ppo graph

* fix naming

* order

* docs

* set class name correctly

* add torch builder

* add custom model support in builder

* cleanup

* remove underscores

* fix py2 compat

* Update dynamic_tf_policy_graph.py

* Update tracking_dict.py

* wip

* rename

* debug level

* rename policy_graph -> policy in new classes

* fix test

* rename ppo tf policy

* port appo too

* forgot grads

* default policy optimizer

* make default config optional

* add config to optimizer

* use lr by default in optimizer

* update

* comments

* remove optimizer

* fix tuple actions support in dynamic tf graph

* [rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy/ (ray-project#4819)

This implements some of the renames proposed in ray-project#4813.
We leave behind backwards-compatibility aliases for *PolicyGraph and SampleBatch.

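The shape of such an alias, sketched in miniature (an illustration, not RLlib's exact shim):

    import warnings

    from ray.rllib.policy import Policy

    class PolicyGraph(Policy):
        # Old name kept importable; warn and defer to the new class.
        def __init__(self, *args, **kwargs):
            warnings.warn("PolicyGraph has been renamed to Policy",
                          DeprecationWarning, stacklevel=2)
            super().__init__(*args, **kwargs)
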
* [Java] Dynamic resource API in Java (ray-project#4824)

* Add default values for Wgym flags

* Fix import

* Fix issue when starting `raylet_monitor` (ray-project#4829)

* Refactor ID Serial 1: Separate ObjectID and TaskID from UniqueID (ray-project#4776)

* Enable BaseId.

* Change TaskID and make python test pass

* Remove unnecessary functions, fix test failures, and change TaskID to 16 bytes

* Java code change draft

* Refine

* Lint

* Update java/api/src/main/java/org/ray/api/id/TaskId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/BaseId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/BaseId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/ObjectId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Address comment

* Lint

* Fix SINGLE_PROCESS

* Fix comments

* Refine code

* Refine test

* Resolve conflict

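A Python-flavored sketch of the refactor's idea (the real change is in C++ and Java; names here are illustrative): a shared base owns byte storage and validity checks, and each subclass pins its own length.

    import os

    class BaseId:
        SIZE = 20  # UniqueID stays 20 bytes

        def __init__(self, data: bytes):
            assert len(data) == self.SIZE, "wrong ID length"
            self.data = data

        @classmethod
        def from_random(cls):
            return cls(os.urandom(cls.SIZE))

        def is_nil(self) -> bool:
            return self.data == bytes(self.SIZE)

    class TaskId(BaseId):
        SIZE = 16  # TaskID shrinks to 16 bytes in this refactor

    class ObjectId(BaseId):
        SIZE = 20
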
* Fix bug in which actor classes are not exported multiple times. (ray-project#4838)

* Bump Ray master version to 0.8.0.dev0 (ray-project#4845)

* Add section to bump version of master branch and cleanup release docs (ray-project#4846)

* Fix import

* Export remote functions when first used and also fix bug in which rem… (ray-project#4844)

* Export remote functions when first used and also fix bug in which remote functions and actor classes are not exported from workers during subsequent ray sessions.

* Documentation update

* Fix tests.

* Fix grammar

* Update wheel versions in documentation to 0.8.0.dev0 and 0.7.0. (ray-project#4847)

* [tune] Later expansion of local_dir (ray-project#4806)

* [rllib] [RFC] Deprecate Python 2 / RLlib (ray-project#4832)

* Fix a typo in kubernetes yaml (ray-project#4872)

* Move global state API out of global_state object. (ray-project#4857)

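A before/after sketch of the flattened calls (the exact set of top-level functions is an assumption):

    import ray

    ray.init()
    print(ray.nodes())              # was ray.global_state.client_table()
    print(ray.cluster_resources())  # was ray.global_state.cluster_resources()
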
* Install bazel in autoscaler development configs. (ray-project#4874)

* [tune] Fix up Ax Search and Examples (ray-project#4851)

* update Ax for cleaner API

* docs update

* [rllib] Update concepts docs and add "Building Policies in Torch/TensorFlow" section (ray-project#4821)

* wip

* fix index

* fix bugs

* todo

* add imports

* note on get ph

* note on get ph

* rename to building custom algs

* add rnn state info

* [rllib] Fix error getting kl when simple_optimizer: True in multi-agent PPO

* Replace ReturnIds with NumReturns in TaskInfo to reduce the size (ray-project#4854)

* Refine TaskInfo

* Fix

* Add a test to print task info size

* Lint

* Refine

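Why a count suffices, as a deliberately hypothetical sketch: each return object ID is a deterministic function of the task ID and the return index, so consumers can recompute IDs on demand instead of shipping the full list. The derivation below is invented for illustration and is not Ray's actual scheme.

    import hashlib

    def return_object_id(task_id: bytes, index: int) -> bytes:
        # Any deterministic (task_id, index) -> id mapping lets NumReturns
        # replace the stored list of ReturnIds.
        return hashlib.sha1(task_id + index.to_bytes(4, "little")).digest()

    task_id = bytes(16)  # 16-byte TaskID, per the ID refactor above
    ids = [return_object_id(task_id, i + 1) for i in range(3)]
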
* Update deps commits of opencensus to support building with bzl 0.25.x (ray-project#4862)

* Update deps to support bzl 0.25.x

* Fix

* Upgrade arrow to latest master (ray-project#4858)

* [tune] Auto-init Ray + default SearchAlg (ray-project#4815)

* Bump version from 0.8.0.dev0 to 0.7.1. (ray-project#4890)

* [rllib] Allow access to batches prior to postprocessing (ray-project#4871)

* [rllib] Fix Multidiscrete support (ray-project#4869)

* Refactor redis callback handling (ray-project#4841)

* Add CallbackReply

* Fix

* fix linting by format.sh

* Fix linting

* Address comments.

* Fix

* Initial high-level code structure of CoreWorker. (ray-project#4875)

* Drop duplicated string format (ray-project#4897)

This string format is unnecessary; java_worker_options is appended to the command line later.

* Refactor ID Serial 2: change all ID functions to `CamelCase` (ray-project#4896)

* Hotfix for change of from_random to FromRandom (ray-project#4909)

* [rllib] Fix documentation on custom policies (ray-project#4910)

* wip

* add docs

* lint

* todo sections

* fix doc

* [rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (ray-project#4894)

* fix torch extra out

* preserve setitem

* fix docs

* [tune] Pretty print params json in logger.py (ray-project#4903)

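The likely shape of the change, stated as an assumption about logger.py:

    import json

    params = {"lr": 1e-4, "train_batch_size": 4000}  # example trial params
    print(json.dumps(params, indent=2, sort_keys=True, default=str))
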
* [sgd] Distributed Training via PyTorch (ray-project#4797)

Implements distributed SGD using distributed PyTorch.

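The building block underneath, as a minimal sketch: each replica wraps its model in DistributedDataParallel so gradients are averaged across workers during backward(). Ray launches one such worker per replica and wires up rank/world_size; the single-process values and address below are placeholders.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    model = DistributedDataParallel(torch.nn.Linear(4, 2))
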
* [rllib] Rough port of DQN to build_tf_policy() pattern (ray-project#4823)

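The pattern in miniature, following the custom-policy examples added in this sync (a plain policy-gradient loss stands in for DQN's TD loss; import paths reflect the 0.7-era layout):

    import tensorflow as tf

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.tf_policy_template import build_tf_policy

    def policy_gradient_loss(policy, batch_tensors):
        actions = batch_tensors[SampleBatch.ACTIONS]
        rewards = batch_tensors[SampleBatch.REWARDS]
        return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)

    MyTFPolicy = build_tf_policy(name="MyTFPolicy",
                                 loss_fn=policy_gradient_loss)
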
* Fetching objects in parallel in _get_arguments_for_execution (ray-project#4775)

* [tune] Disallow setting resources_per_trial when it is already configured (ray-project#4880)

* disallow it

* import fix

* fix example

* fix test

* fix tests

* Update mock.py

* fix

* make less convoluted

* fix tests

* [rllib] Rename PolicyEvaluator => RolloutWorker (ray-project#4820)

* Fix local cluster yaml (ray-project#4918)

* [tune] Directional metrics for components (ray-project#4120) (ray-project#4915)

* [Core Worker] implement ObjectInterface and add test framework (ray-project#4899)

* [tune] Make PBT Quantile fraction configurable (ray-project#4912)

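A usage sketch, assuming quantile_fraction is the knob this PR exposes (other arguments mirror the pbt.py workload updated below):

    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="episode_reward_mean",
        mode="max",
        perturbation_interval=10,
        quantile_fraction=0.25,  # exploit/explore cutoff (assumed default)
        hyperparam_mutations={"lr": [0.1, 0.01, 0.001, 0.0001]})
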
* Better organize ray_common module (ray-project#4898)

* Fix error

* Fix compute actions return value
stefanpantic authored Jun 6, 2019
1 parent a86e198 commit 62c0a9e
Showing 228 changed files with 5,402 additions and 3,085 deletions.
68 changes: 52 additions & 16 deletions BUILD.bazel
@@ -77,6 +77,7 @@ cc_library(
"src/ray/raylet/mock_gcs_client.cc",
"src/ray/raylet/monitor_main.cc",
"src/ray/raylet/*_test.cc",
"src/ray/raylet/main.cc",
],
),
hdrs = glob([
@@ -105,6 +106,39 @@
],
)

cc_library(
name = "core_worker_lib",
srcs = glob(
[
"src/ray/core_worker/*.cc",
],
exclude = [
"src/ray/core_worker/*_test.cc",
],
),
hdrs = glob([
"src/ray/core_worker/*.h",
]),
copts = COPTS,
deps = [
":ray_common",
":ray_util",
":raylet_lib",
],
)

# This test is run by src/ray/test/run_core_worker_tests.sh
cc_binary(
name = "core_worker_test",
srcs = ["src/ray/core_worker/core_worker_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
":gcs",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "lineage_cache_test",
srcs = ["src/ray/raylet/lineage_cache_test.cc"],
@@ -247,16 +281,13 @@ cc_library(
name = "ray_util",
srcs = glob(
[
"src/ray/*.cc",
"src/ray/util/*.cc",
],
exclude = [
"src/ray/util/logging_test.cc",
"src/ray/util/signal_test.cc",
"src/ray/util/*_test.cc",
],
),
hdrs = glob([
"src/ray/*.h",
"src/ray/util/*.h",
]),
copts = COPTS,
@@ -272,23 +303,28 @@ cc_library(

cc_library(
name = "ray_common",
srcs = [
"src/ray/common/client_connection.cc",
"src/ray/common/common_protocol.cc",
],
hdrs = [
"src/ray/common/client_connection.h",
"src/ray/common/common_protocol.h",
],
srcs = glob(
[
"src/ray/common/*.cc",
],
exclude = [
"src/ray/common/*_test.cc",
],
),
hdrs = glob(
[
"src/ray/common/*.h",
],
),
copts = COPTS,
includes = [
"src/ray/gcs/format",
],
deps = [
":gcs_fbs",
":node_manager_fbs",
":ray_util",
"@boost//:asio",
"@plasma//:plasma_client",
],
)

@@ -432,7 +468,7 @@ cc_binary(
srcs = [
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h",
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc",
"src/ray/id.h",
"src/ray/common/id.h",
"src/ray/raylet/raylet_client.h",
"src/ray/util/logging.h",
"@bazel_tools//tools/jdk:jni_header",
@@ -637,8 +673,8 @@ genrule(
cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done &&
mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ &&
for f in $(locations //:python_node_manager_fbs); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
done &&
echo $$WORK_DIR > $@
""",
6 changes: 6 additions & 0 deletions bazel/BUILD.plasma
@@ -25,11 +25,13 @@ cc_library(
name = "arrow",
srcs = [
"cpp/src/arrow/buffer.cc",
"cpp/src/arrow/io/interfaces.cc",
"cpp/src/arrow/memory_pool.cc",
"cpp/src/arrow/status.cc",
"cpp/src/arrow/util/io-util.cc",
"cpp/src/arrow/util/logging.cc",
"cpp/src/arrow/util/memory.cc",
"cpp/src/arrow/util/string_builder.cc",
"cpp/src/arrow/util/thread-pool.cc",
],
hdrs = [
@@ -42,6 +44,7 @@ cc_library(
"cpp/src/arrow/util/logging.h",
"cpp/src/arrow/util/macros.h",
"cpp/src/arrow/util/memory.h",
"cpp/src/arrow/util/stl.h",
"cpp/src/arrow/util/string_builder.h",
"cpp/src/arrow/util/string_view.h",
"cpp/src/arrow/util/thread-pool.h",
@@ -53,6 +56,9 @@ cc_library(
"cpp/src/arrow/vendored/xxhash/xxhash.h",
],
strip_include_prefix = "cpp/src",
deps = [
"@boost//:filesystem",
],
)

cc_library(
30 changes: 15 additions & 15 deletions bazel/ray_deps_setup.bzl
@@ -3,9 +3,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

def ray_deps_setup():
RULES_JVM_EXTERNAL_TAG = "1.2"

RULES_JVM_EXTERNAL_SHA = "e5c68b87f750309a79f59c2b69ead5c3221ffa54ff9496306937bfa1c9c8c86b"

http_archive(
name = "rules_jvm_external",
sha256 = RULES_JVM_EXTERNAL_SHA,
@@ -18,72 +18,72 @@ def ray_deps_setup():
strip_prefix = "bazel-common-f1115e0f777f08c3cdb115526c4e663005bec69b",
url = "https://github.com/google/bazel-common/archive/f1115e0f777f08c3cdb115526c4e663005bec69b.zip",
)

BAZEL_SKYLIB_TAG = "0.6.0"

http_archive(
name = "bazel_skylib",
strip_prefix = "bazel-skylib-%s" % BAZEL_SKYLIB_TAG,
url = "https://github.com/bazelbuild/bazel-skylib/archive/%s.tar.gz" % BAZEL_SKYLIB_TAG,
)

git_repository(
name = "com_github_checkstyle_java",
commit = "85f37871ca03b9d3fee63c69c8107f167e24e77b",
remote = "https://github.com/ruifangChen/checkstyle_java",
)

git_repository(
name = "com_github_nelhage_rules_boost",
commit = "5171b9724fbb39c5fdad37b9ca9b544e8858d8ac",
remote = "https://github.com/ray-project/rules_boost",
)

git_repository(
name = "com_github_google_flatbuffers",
commit = "63d51afd1196336a7d1f56a988091ef05deb1c62",
remote = "https://github.com/google/flatbuffers.git",
)

git_repository(
name = "com_google_googletest",
commit = "3306848f697568aacf4bcca330f6bdd5ce671899",
remote = "https://github.com/google/googletest",
)

git_repository(
name = "com_github_gflags_gflags",
remote = "https://github.com/gflags/gflags.git",
tag = "v2.2.2",
)

new_git_repository(
name = "com_github_google_glog",
build_file = "@//bazel:BUILD.glog",
commit = "5c576f78c49b28d89b23fbb1fc80f54c879ec02e",
remote = "https://github.com/google/glog",
)

new_git_repository(
name = "plasma",
build_file = "@//bazel:BUILD.plasma",
commit = "d00497b38be84fd77c40cbf77f3422f2a81c44f9",
commit = "9fcc12fc094b85ec2e3e9798bae5c8151d14df5e",
remote = "https://github.com/apache/arrow",
)

new_git_repository(
name = "cython",
build_file = "@//bazel:BUILD.cython",
commit = "49414dbc7ddc2ca2979d6dbe1e44714b10d72e7e",
remote = "https://github.com/cython/cython",
)

http_archive(
name = "io_opencensus_cpp",
strip_prefix = "opencensus-cpp-3aa11f20dd610cb8d2f7c62e58d1e69196aadf11",
urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/3aa11f20dd610cb8d2f7c62e58d1e69196aadf11.zip"],
)

# OpenCensus depends on Abseil so we have to explicitly pull it in.
# This is how diamond dependencies are prevented.
git_repository(
Expand All @@ -96,7 +96,7 @@ def ray_deps_setup():
http_archive(
name = "com_github_jupp0r_prometheus_cpp",
strip_prefix = "prometheus-cpp-master",

# TODO(qwang): We should use the repository of `jupp0r` here when this PR
# `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged.
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"],
4 changes: 2 additions & 2 deletions build.sh
@@ -101,8 +101,8 @@ pushd "$BUILD_DIR"
# generated from https://github.com/ray-project/arrow-build from
# the commit listed in the command.
$PYTHON_EXECUTABLE -m pip install \
--target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.12.0.RAY \
--find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/ca1fa51f0901f5a4298f0e4faea00f24e5dd7bb7/index.html
--target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.14.0.RAY \
--find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/9f35817b35f9d0614a736a497d70de2cf07fed52/index.html
export PYTHON_BIN_PATH="$PYTHON_EXECUTABLE"

if [ "$RAY_BUILD_JAVA" == "YES" ]; then
23 changes: 1 addition & 22 deletions ci/jenkins_tests/run_multi_node_tests.sh
@@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \
######################## SGD TESTS #################################

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps --tune
python -m pytest /ray/python/ray/experimental/sgd/tests
13 changes: 11 additions & 2 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -302,7 +302,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_checkpoint_restore.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_policy_evaluator.py
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_rollout_worker.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_nested_spaces.py
@@ -390,7 +390,16 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2
3 changes: 2 additions & 1 deletion ci/long_running_tests/workloads/pbt.py
@@ -37,7 +37,8 @@

pbt = PopulationBasedTraining(
time_attr="training_iteration",
reward_attr="episode_reward_mean",
metric="episode_reward_mean",
mode="max",
perturbation_interval=10,
hyperparam_mutations={
"lr": [0.1, 0.01, 0.001, 0.0001],
4 changes: 4 additions & 0 deletions doc/source/conf.py
@@ -53,6 +53,10 @@
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"torch",
"torch.distributed",
"torch.nn",
"torch.utils.data",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
48 changes: 48 additions & 0 deletions doc/source/distributed_training.rst
@@ -0,0 +1,48 @@
Distributed Training (Experimental)
===================================


Ray includes abstractions for distributed model training that integrate with
deep learning frameworks, such as PyTorch.

Ray Train is built on top of the Ray task and actor abstractions to provide
seamless integration into existing Ray applications.

PyTorch Interface
-----------------

To use Ray Train with PyTorch, pass model and data creator functions to the
``ray.experimental.sgd.pytorch.PyTorchTrainer`` class.
To drive the distributed training, ``trainer.train()`` can be called
repeatedly.

.. code-block:: python

    model_creator = lambda config: YourPyTorchModel()
    data_creator = lambda config: (YourTrainingSet(), YourValidationSet())

    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator=utils.sgd_mse_optimizer,
        config={"lr": 1e-4},
        num_replicas=2,
        resources_per_replica=Resources(num_gpus=1),
        batch_size=16,
        backend="auto")

    for i in range(NUM_EPOCHS):
        trainer.train()

Under the hood, Ray Train will create *replicas* of your model
(controlled by ``num_replicas``), each managed by a worker. Each replica
can manage multiple devices (e.g. GPUs), controlled by
``resources_per_replica``, which allows training of large models across
multiple GPUs. The ``PyTorchTrainer`` class coordinates the distributed
computation and training loop to improve the model.

The full documentation for ``PyTorchTrainer`` is as follows:

.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer
:members:

.. automethod:: __init__