For release 2025.01.00 #102

Merged: 108 commits, merged Jan 24, 2025
b7b84e0
updating 4d handling
oelbert Aug 27, 2024
bf4e2f2
debug 4d test data
oelbert Aug 27, 2024
dfc4e5f
more iter
oelbert Aug 27, 2024
6ab1dd5
moving ser_to_nc here
oelbert Aug 27, 2024
7bcce4b
merge develop
oelbert Sep 4, 2024
2d7062a
updating datatype in translate test
oelbert Sep 5, 2024
d44551e
typing works
oelbert Sep 5, 2024
ed3d431
fix dict, lint
oelbert Sep 5, 2024
2225fd9
remove empty line
oelbert Sep 5, 2024
f3cf32d
change from 4d to Nd
oelbert Sep 30, 2024
f894a49
Merge branch 'develop' into feature/4d_data
FlorianDeconinck Sep 30, 2024
ed4ddd4
Expose `k_start` and `k_end` automatically for any FrozenStencil
Oct 7, 2024
5b09a67
Fix k_start + utest
Oct 7, 2024
b0b2940
lint
Oct 7, 2024
0c7c902
Fix for 2d stencils
Oct 7, 2024
b5a7fa7
Merge pull request #78 from FlorianDeconinck/feature/K_axis_bounds_in…
fmalatino Oct 7, 2024
4b8d4b9
Add threshold overrides to the multimodal metric
Oct 7, 2024
0f1644c
Merge pull request #71 from oelbert/feature/4d_data
FlorianDeconinck Oct 9, 2024
720149a
Always report results, add summary with one liners
Oct 9, 2024
c9da47b
Merge branch 'develop' into feature/multimodal_metric_threshold_override
FlorianDeconinck Oct 9, 2024
020a259
Remove "mmr" from the keys
Oct 9, 2024
75f886f
Merge remote-tracking branch 'fdeconinck/feature/multimodal_metric_th…
Oct 9, 2024
d59918b
README in testing
Oct 10, 2024
4bdc914
Better Latex (?)
Oct 10, 2024
bdb3820
Better Latex (?)
Oct 10, 2024
e17539d
fixing a typo that breaks bools in translate tests (#80)
oelbert Oct 10, 2024
0fc563e
Merge branch 'develop' into feature/multimodal_metric_threshold_override
FlorianDeconinck Oct 10, 2024
f8105dd
Fix summary filename
FlorianDeconinck Oct 11, 2024
6389638
Merge remote-tracking branch 'fdeconinck/feature/multimodal_metric_th…
FlorianDeconinck Oct 11, 2024
5ac067f
Fix report, filename
FlorianDeconinck Oct 13, 2024
8870e46
Fix choosing right absolute difference for F32
FlorianDeconinck Oct 17, 2024
4d6d96c
Make robust for NaN value
FlorianDeconinck Oct 23, 2024
aed5912
Merge pull request #79 from FlorianDeconinck/feature/multimodal_metri…
FlorianDeconinck Oct 23, 2024
001d2bd
Detect when array have different dimensions, if only one dimension, c…
FlorianDeconinck Oct 24, 2024
eaf1d20
Lint
FlorianDeconinck Oct 24, 2024
489aab1
Add rank 0 to the data
FlorianDeconinck Oct 24, 2024
2af8dfb
Check data exists for rank, skip & print if not
FlorianDeconinck Oct 24, 2024
ce38ce0
Fix bad logic on skip test for parallel
FlorianDeconinck Oct 31, 2024
05952aa
Verbose exported names
FlorianDeconinck Nov 12, 2024
3347431
Make boilerplate calls more nimble
FlorianDeconinck Nov 12, 2024
1eda108
New option: `which_savepoint`
FlorianDeconinck Nov 12, 2024
b1b3ac0
QOL for mypy/flak8 type hints
FlorianDeconinck Nov 12, 2024
264da4e
Merge pull request #83 from FlorianDeconinck/fix/boilerplate
FlorianDeconinck Nov 13, 2024
153138e
Merge pull request #82 from FlorianDeconinck/feature/serialbox_netcdf…
FlorianDeconinck Nov 13, 2024
53b273b
Merge pull request #84 from FlorianDeconinck/qol/translate_test_which…
FlorianDeconinck Nov 14, 2024
88129fc
Add SECONDS_PER_DAY as a constants following mixed precision standards
FlorianDeconinck Dec 2, 2024
c436b0b
Lint
FlorianDeconinck Dec 2, 2024
3c1ee68
Merge pull request #86 from FlorianDeconinck/feature/seconds_per_day_…
FlorianDeconinck Dec 2, 2024
9efb5f4
Cleanups in dace orchestration
Dec 4, 2024
c7d6c4f
Rename program -> dace_program
Dec 4, 2024
60a8f59
Merge pull request #87 from romanc/romanc/cleanups-orchestartion
FlorianDeconinck Dec 5, 2024
ce3ac7e
Make sure all constants adhere to the floating point precision set by…
FlorianDeconinck Dec 9, 2024
502486f
Move `is_float` to `dsl.typing`
FlorianDeconinck Dec 10, 2024
a13776f
Move Quantity to sub-directory + breakout the subcomponent
FlorianDeconinck Dec 10, 2024
937417b
Fix tests
FlorianDeconinck Dec 10, 2024
45c3180
Lint
FlorianDeconinck Dec 10, 2024
0330cdb
Remove `cp.ndarray` since cupy is optional
FlorianDeconinck Dec 10, 2024
18b2f3f
Restore workaround for optional cupy
FlorianDeconinck Dec 10, 2024
7076740
"GFS" -> "UFS"
FlorianDeconinck Dec 10, 2024
a8a7c85
Cupy trick for metadata
FlorianDeconinck Dec 10, 2024
a7ee68f
Add comments for constant explanation
FlorianDeconinck Dec 11, 2024
28e2375
Describe 64/32-bit FloatFields
FlorianDeconinck Dec 11, 2024
cf4c2ce
Merge pull request #88 from FlorianDeconinck/fix/F32/Constants
FlorianDeconinck Dec 11, 2024
8daf5bd
Make sure the `make_storage_data` respects the array dtype.
FlorianDeconinck Dec 11, 2024
9faa405
Fix logic for MultiModal metric and verbose it
FlorianDeconinck Dec 11, 2024
7c03e92
Merge pull request #90 from FlorianDeconinck/feature/fixed_precision_…
FlorianDeconinck Dec 11, 2024
359812a
Merge branch 'develop' into fix/translate_test_storage_precision
FlorianDeconinck Dec 11, 2024
75b4741
Added an MPI all_reduce for quantities based on SUM operation to comm…
gmao-ckung Dec 11, 2024
4c8632c
linted
gmao-ckung Dec 11, 2024
a2fac9f
Add initial skeleton of pytest test for all reduce
gmao-ckung Dec 13, 2024
38ee6a6
Merge pull request #91 from FlorianDeconinck/fix/translate_test_stora…
FlorianDeconinck Dec 13, 2024
8c5b5d5
Added assertion tests for 1, 2 and 3D quantities passed through mpi_a…
gmao-ckung Dec 13, 2024
fb4e740
Linted
gmao-ckung Dec 13, 2024
34f82fb
Added pytest.mark to skip test if mpi4py isn't available
gmao-ckung Dec 13, 2024
b4a6a54
lint changes
gmao-ckung Dec 16, 2024
f5ce883
Addressed PR comments and added additional CPU backends to unit test
gmao-ckung Dec 16, 2024
2e41349
Merge branch 'feature/mpi_allreduce_sum' of https://github.com/NOAA-G…
gmao-ckung Dec 16, 2024
2e669db
Added setters for various Quantity properties to enable setting of Qu…
gmao-ckung Dec 18, 2024
fd2fa97
Added function in QuantityMetadata class that allows copying of Metad…
gmao-ckung Dec 19, 2024
ad19be3
Expose all SG metric terms in grid_data
FlorianDeconinck Dec 20, 2024
76f53c8
Merge pull request #93 from FlorianDeconinck/feature/minor/add_all_co…
FlorianDeconinck Dec 20, 2024
cc620c6
Add `Allreduce` and all MPI OP
FlorianDeconinck Dec 22, 2024
0e8089e
Update utest
FlorianDeconinck Dec 22, 2024
2188c75
Fix `local_comm`
FlorianDeconinck Dec 22, 2024
f8cc2ce
Fix utest
FlorianDeconinck Dec 22, 2024
7ad271f
Enforce `comm_abc.Comm` into Communicator
FlorianDeconinck Dec 22, 2024
07cd0f3
Fix `comm` object in serial utest
FlorianDeconinck Dec 22, 2024
224e6e2
Lint + `MPIComm` on testing architecture
FlorianDeconinck Dec 22, 2024
312b492
Merge branch 'develop' into feature/mpi_allreduce_sum
FlorianDeconinck Dec 22, 2024
f99914a
Make sure the correct allocator backend is used for Quantities
FlorianDeconinck Dec 27, 2024
760578c
Add in_place option for Allreduce
FlorianDeconinck Dec 30, 2024
a8a2e73
Merge pull request #95 from FlorianDeconinck/fix/boilerplate_on_gpu
FlorianDeconinck Jan 2, 2025
c758ffb
Merge branch 'develop' into feature/mpi_allreduce_sum
FlorianDeconinck Jan 7, 2025
9f5e50c
Merge pull request #92 from NOAA-GFDL/feature/mpi_allreduce_sum
FlorianDeconinck Jan 7, 2025
4e06ee8
Cleanup ndsl/dsl/dace/utils.py (#96)
romanc Jan 7, 2025
94b35d0
Merge branch 'develop' into refactor/quantity
FlorianDeconinck Jan 7, 2025
fcfb058
Fix merge
FlorianDeconinck Jan 7, 2025
1c7c30c
Merge pull request #89 from FlorianDeconinck/refactor/quantity
FlorianDeconinck Jan 14, 2025
1fa8d79
Hotfix for grid generation use of mpi operators
fmalatino Jan 16, 2025
a75a1d7
Merge pull request #98 from fmalatino/fix/reductop
FlorianDeconinck Jan 16, 2025
3252ec7
Merge examples/mpi/.gitignore into top-level .gitignore
romanc Jan 20, 2025
04ecf87
Remove hard-coded __version__ numbers
romanc Jan 20, 2025
4881b34
Fixing a bunch of typos
romanc Jan 20, 2025
1da613d
Merge pull request #99 from romanc/romanc/cleanup-gitignore
fmalatino Jan 21, 2025
a51daf3
hotfix netcdf version for dockerfiles
oelbert Jan 22, 2025
acb8c0d
Merge pull request #100 from oelbert/develop
fmalatino Jan 22, 2025
3f77863
Updated version number in setup.py to reflect new release, 2025.01.00
fmalatino Jan 23, 2025
b7db259
Merge pull request #101 from fmalatino/rc-2025.01.00
fmalatino Jan 24, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -8,6 +8,9 @@ driver/examples/comm
20*-*-*-*-*-*.json
*.pkl

# example outputs
examples/mpi/output

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
4 changes: 2 additions & 2 deletions examples/NDSL/03_orchestration_basics.ipynb
@@ -37,7 +37,7 @@
")\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.dsl.typing import FloatField, Float\n",
"from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu"
"from ndsl.boilerplate import get_factories_single_tile_orchestrated"
]
},
{
@@ -126,7 +126,7 @@
" tile_size = (3, 3, 3)\n",
"\n",
" # Setup\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated_cpu(\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated(\n",
" nx=tile_size[0],\n",
" ny=tile_size[1],\n",
" nz=tile_size[2],\n",
1 change: 0 additions & 1 deletion examples/mpi/.gitignore

This file was deleted.

2 changes: 1 addition & 1 deletion ndsl/__init__.py
@@ -10,7 +10,7 @@
from .dsl.dace.utils import (
ArrayReport,
DaCeProgress,
MaxBandwithBenchmarkProgram,
MaxBandwidthBenchmarkProgram,
StorageReport,
)
from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater
18 changes: 10 additions & 8 deletions ndsl/boilerplate.py
@@ -16,6 +16,7 @@
TileCommunicator,
TilePartitioner,
)
from ndsl.optional_imports import cupy as cp


def _get_factories(
@@ -74,36 +74,37 @@ def _get_factories(

grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm)
stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing)
quantity_factory = QuantityFactory(sizer, np)
quantity_factory = QuantityFactory(
sizer, cp if stencil_config.is_gpu_backend else np
)

return stencil_factory, quantity_factory


def get_factories_single_tile_orchestrated_cpu(
nx, ny, nz, nhalo
def get_factories_single_tile_orchestrated(
nx, ny, nz, nhalo, on_cpu: bool = True
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="dace:cpu",
backend="dace:cpu" if on_cpu else "dace:gpu",
orchestration=DaCeOrchestration.BuildAndRun,
topology="tile",
)


def get_factories_single_tile_numpy(
nx, ny, nz, nhalo
def get_factories_single_tile(
nx, ny, nz, nhalo, backend: str = "numpy"
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for Numpy, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="numpy",
backend=backend,
orchestration=DaCeOrchestration.Python,
topology="tile",
)
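The boilerplate diff above folds the CPU-specific helpers into single entry points that pick the backend and allocator at call time. Below is a minimal standalone sketch of that selection pattern — the function and helper names here are stand-ins written for illustration (the real API lives in `ndsl.boilerplate` and `StencilConfig`), not code from this PR:

```python
# Sketch of the backend-selection pattern: one factory entry point that
# derives the DaCe target and the array allocator (cupy vs. numpy) from
# a single flag, mirroring `get_factories_single_tile_orchestrated`.

def is_gpu_backend(backend: str) -> bool:
    # Stand-in for StencilConfig.is_gpu_backend: GPU backends name a GPU target.
    return backend.endswith(":gpu") or "cuda" in backend


def pick_allocator(backend: str):
    """Return the array module to allocate with: cupy on GPU, numpy otherwise."""
    if is_gpu_backend(backend):
        try:
            import cupy as cp  # optional dependency, as in the diff

            return cp
        except ImportError:
            raise RuntimeError(f"backend '{backend}' needs cupy, which is missing")
    import numpy as np

    return np


def get_factories_single_tile_orchestrated(nx, ny, nz, nhalo, on_cpu: bool = True):
    backend = "dace:cpu" if on_cpu else "dace:gpu"
    allocator = pick_allocator(backend)
    # Stand-in return value; the real function builds and returns
    # (StencilFactory, QuantityFactory).
    return backend, allocator
```

Callers of the old `get_factories_single_tile_orchestrated_cpu(nx, ny, nz, nhalo)` keep the same behavior by calling the new name with the default `on_cpu=True`.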
4 changes: 2 additions & 2 deletions ndsl/comm/boundary.py
@@ -28,7 +28,7 @@ def send_view(self, quantity: Quantity, n_points: int):
return self._view(quantity, n_points, interior=True)

def recv_view(self, quantity: Quantity, n_points: int):
"""Return a sliced view of points which should be recieved at this boundary.
"""Return a sliced view of points which should be received at this boundary.

Args:
quantity: quantity for which to return a slice
@@ -37,7 +37,7 @@ def recv_view(self, quantity: Quantity, n_points: int):
return self._view(quantity, n_points, interior=False)

def send_slice(self, specification: QuantityHaloSpec) -> Tuple[slice]:
"""Return the index slices which shoud be sent at this boundary.
"""Return the index slices which should be sent at this boundary.

Args:
specification: data specifications for the halo. Including shape
12 changes: 9 additions & 3 deletions ndsl/comm/caching_comm.py
@@ -5,7 +5,7 @@

import numpy as np

from ndsl.comm.comm_abc import Comm, Request
from ndsl.comm.comm_abc import Comm, ReductionOperator, Request


T = TypeVar("T")
@@ -147,9 +147,12 @@ def Split(self, color, key) -> "CachingCommReader":
new_data = self._data.get_split()
return CachingCommReader(data=new_data)

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
return self._data.get_generic_obj()

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommReader.Allreduce")

@classmethod
def load(cls, file: BinaryIO) -> "CachingCommReader":
data = CachingCommData.load(file)
@@ -229,7 +232,10 @@ def Split(self, color, key) -> "CachingCommWriter":
def dump(self, file: BinaryIO):
self._data.dump(file)

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
result = self._comm.allreduce(sendobj, op)
self._data.generic_obj_buffers.append(copy.deepcopy(result))
return result

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommWriter.Allreduce")
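The `CachingCommWriter.allreduce` change above follows the file's existing record/replay scheme: the writer runs the real reduction and deep-copies the result into a buffer, so a `CachingCommReader` can later replay the run without MPI. A self-contained sketch of that idea, with hypothetical class names (the real classes carry more state and serialization):

```python
import copy


class RecordingComm:
    """Wraps a live comm and records each allreduce result (cf. CachingCommWriter)."""

    def __init__(self, comm):
        self._comm = comm
        self.log = []

    def allreduce(self, sendobj, op=None):
        result = self._comm.allreduce(sendobj, op)
        self.log.append(copy.deepcopy(result))  # deep copy: later mutation must not alter the log
        return result


class ReplayComm:
    """Replays recorded results with no communicator at all (cf. CachingCommReader)."""

    def __init__(self, log):
        self._log = list(log)

    def allreduce(self, sendobj, op=None):
        # The send object is ignored; the recorded run is the source of truth.
        return self._log.pop(0)


class SingleRankComm:
    # Trivial stand-in: on one rank, an allreduce just returns its input.
    def allreduce(self, sendobj, op=None):
        return sendobj


writer = RecordingComm(SingleRankComm())
writer.allreduce(41)
reader = ReplayComm(writer.log)
```

The uppercase `Allreduce` stays `NotImplementedError` in both caching classes because buffer-based reductions write into a caller-owned array, which this object-logging scheme does not capture.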
29 changes: 28 additions & 1 deletion ndsl/comm/comm_abc.py
@@ -1,10 +1,30 @@
import abc
import enum
from typing import List, Optional, TypeVar


T = TypeVar("T")


@enum.unique
class ReductionOperator(enum.Enum):
OP_NULL = enum.auto()
MAX = enum.auto()
MIN = enum.auto()
SUM = enum.auto()
PROD = enum.auto()
LAND = enum.auto()
BAND = enum.auto()
LOR = enum.auto()
BOR = enum.auto()
LXOR = enum.auto()
BXOR = enum.auto()
MAXLOC = enum.auto()
MINLOC = enum.auto()
REPLACE = enum.auto()
NO_OP = enum.auto()


class Request(abc.ABC):
@abc.abstractmethod
def wait(self):
@@ -69,5 +89,12 @@ def Split(self, color, key) -> "Comm":
...

@abc.abstractmethod
def allreduce(self, sendobj: T, op=None) -> T:
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
...

@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
...

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
...
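The new `ReductionOperator` enum above mirrors the standard MPI operations so NDSL code can name a reduction without importing mpi4py; each concrete `Comm` backend then translates the enum to its own representation. A hedged sketch of such a translation table — the local (non-MPI) interpretation below is an assumption for illustration, not part of the PR; an MPI backend would map the members to `MPI.MAX`, `MPI.SUM`, and so on:

```python
import enum
import operator
from functools import reduce


@enum.unique
class ReductionOperator(enum.Enum):
    # Abridged to the arithmetic subset; the PR also defines logical,
    # bitwise, and locator operators.
    MAX = enum.auto()
    MIN = enum.auto()
    SUM = enum.auto()
    PROD = enum.auto()


# One possible backend translation: interpret each operator as a Python
# binary function, which is enough for a single-process test double.
_LOCAL_OPS = {
    ReductionOperator.MAX: max,
    ReductionOperator.MIN: min,
    ReductionOperator.SUM: operator.add,
    ReductionOperator.PROD: operator.mul,
}


def local_reduce(values, op: ReductionOperator):
    """Fold `values` with the Python function chosen for `op`."""
    return reduce(_LOCAL_OPS[op], values)
```

Keeping the enum in `comm_abc` means callers of `allreduce`/`Allreduce` stay importable even when mpi4py is absent, matching the abstract-comm design of the rest of the file.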
83 changes: 70 additions & 13 deletions ndsl/comm/communicator.py
@@ -6,6 +6,8 @@
import ndsl.constants as constants
from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer
from ndsl.comm.boundary import Boundary
from ndsl.comm.comm_abc import Comm as CommABC
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner
from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
from ndsl.performance.timer import NullTimer, Timer
@@ -44,7 +46,11 @@ def to_numpy(array, dtype=None) -> np.ndarray:

class Communicator(abc.ABC):
def __init__(
self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None
self,
comm: CommABC,
partitioner,
force_cpu: bool = False,
timer: Optional[Timer] = None,
):
self.comm = comm
self.partitioner: Partitioner = partitioner
@@ -61,7 +67,7 @@ def tile(self) -> "TileCommunicator":
@abc.abstractmethod
def from_layout(
cls,
comm,
comm: CommABC,
layout: Tuple[int, int],
force_cpu: bool = False,
timer: Optional[Timer] = None,
@@ -93,17 +99,63 @@ def _device_synchronize():
# this is a method so we can profile it separately from other device syncs
device_synchronize()

def _create_all_reduce_quantity(
self, input_metadata: QuantityMetadata, input_data
) -> Quantity:
"""Create a Quantity for all_reduce data and metadata"""
all_reduce_quantity = Quantity(
input_data,
dims=input_metadata.dims,
units=input_metadata.units,
origin=input_metadata.origin,
extent=input_metadata.extent,
gt4py_backend=input_metadata.gt4py_backend,
allow_mismatch_float_precision=False,
)
return all_reduce_quantity

def all_reduce(
self,
input_quantity: Quantity,
op: ReductionOperator,
output_quantity: Quantity = None,
):
reduced_quantity_data = self.comm.allreduce(input_quantity.data, op)
if output_quantity is None:
all_reduce_quantity = self._create_all_reduce_quantity(
input_quantity.metadata, reduced_quantity_data
)
return all_reduce_quantity
else:
if output_quantity.data.shape != input_quantity.data.shape:
raise TypeError("Shapes not matching")

input_quantity.metadata.duplicate_metadata(output_quantity.metadata)

output_quantity.data = reduced_quantity_data

def all_reduce_per_element(
self,
input_quantity: Quantity,
output_quantity: Quantity,
op: ReductionOperator,
):
self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def all_reduce_per_element_in_place(
self, quantity: Quantity, op: ReductionOperator
):
self.comm.Allreduce_inplace(quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
numpy_module.zeros, recvbuf
) as recv:
self.comm.Scatter(send, recv, **kwargs)
with send_buffer(numpy_module.zeros, sendbuf) as send:
with recv_buffer(numpy_module.zeros, recvbuf) as recv:
self.comm.Scatter(send, recv, **kwargs)

def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
numpy_module.zeros, recvbuf
) as recv:
self.comm.Gather(send, recv, **kwargs)
with send_buffer(numpy_module.zeros, sendbuf) as send:
with recv_buffer(numpy_module.zeros, recvbuf) as recv:
self.comm.Gather(send, recv, **kwargs)

def scatter(
self,
@@ -252,7 +304,7 @@ def gather_state(self, send_state=None, recv_state=None, transfer_type=None):

Args:
send_state: the model state to be sent containing the subtile data
recv_state: the pre-allocated state in which to recieve the full tile
recv_state: the pre-allocated state in which to receive the full tile
state. Only variables which are scattered will be written to.
Returns:
recv_state: on the root rank, the state containing the entire tile
@@ -288,7 +340,7 @@ def scatter_state(self, send_state=None, recv_state=None):
Args:
send_state: the model state to be sent containing the entire tile,
required only from the root rank
recv_state: the pre-allocated state in which to recieve the scattered
recv_state: the pre-allocated state in which to receive the scattered
state. Only variables which are scattered will be written to.
Returns:
rank_state: the state corresponding to this rank's subdomain
@@ -709,7 +761,7 @@ class CubedSphereCommunicator(Communicator):

def __init__(
self,
comm,
comm: CommABC,
partitioner: CubedSpherePartitioner,
force_cpu: bool = False,
timer: Optional[Timer] = None,
@@ -722,6 +774,11 @@ def __init__(
force_cpu: Force all communication to go through central memory.
timer: Time communication operations.
"""
if not issubclass(type(comm), CommABC):
raise TypeError(
"Communicator needs to be instantiated with communication subsystem"
f" derived from `comm_abc.Comm`, got {type(comm)}."
)
if comm.Get_size() != partitioner.total_ranks:
raise ValueError(
f"was given a partitioner for {partitioner.total_ranks} ranks but a "
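The new `Communicator.all_reduce` in the diff above supports two calling modes: with no `output_quantity` it allocates and returns a fresh result, and with one it validates the shape and fills it in place. A standalone sketch of that allocate-or-fill pattern using plain lists — `DummyComm` and this `all_reduce` are hypothetical stand-ins written for illustration, not the ndsl API:

```python
import copy


class DummyComm:
    """Single-rank stand-in: an allreduce with one participant returns its input."""

    def allreduce(self, sendobj, op=None):
        return copy.deepcopy(sendobj)


def all_reduce(comm, input_data, output=None):
    reduced = comm.allreduce(input_data, op="sum")
    if output is None:
        # No destination given: allocate and return a new result object.
        return reduced
    if len(output) != len(input_data):
        raise TypeError("Shapes not matching")
    # Destination given: fill it in place and return nothing, as the
    # diff's `else` branch does.
    output[:] = reduced
```

Note the asymmetry this preserves from the diff: the allocating mode returns a value while the filling mode returns `None`, so callers must pick one style and stick to it.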
10 changes: 8 additions & 2 deletions ndsl/comm/local_comm.py
@@ -189,8 +189,14 @@ def Split(self, color, key):
self._split_comms[color].append(new_comm)
return new_comm

def allreduce(self, sendobj, op=None) -> Any:
def allreduce(self, sendobj, op=None, recvobj=None) -> Any:
raise NotImplementedError(
"sendrecv fundamentally cannot be written for LocalComm, "
"allreduce fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)

def Allreduce(self, sendobj, recvobj, op) -> Any:
raise NotImplementedError(
"Allreduce fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)