Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use GE backend for graph mode #296

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions demo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import mindspore as ms
from mindspore import Tensor, context, nn
from mindspore._c_expression import ms_ctx_param

from mindyolo.data import COCO80_TO_COCO91_CLASS
from mindyolo.models import create_model
Expand Down Expand Up @@ -53,6 +54,8 @@ def get_parser_infer(parents=None):
def set_default_infer(args):
# Set Context
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
if args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
Expand Down
17 changes: 10 additions & 7 deletions mindyolo/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import numpy as np

import mindspore as ms
from mindspore import context, ops, Tensor, nn
from mindspore import ops, Tensor, nn
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.context import ParallelMode
from mindspore import ParallelMode
from mindspore._c_expression import ms_ctx_param

from mindyolo.utils import logger

Expand All @@ -21,22 +22,24 @@ def set_seed(seed=2):

def set_default(args):
# Set Context
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
if args.device_target == "Ascend":
device_id = int(os.getenv("DEVICE_ID", 0))
context.set_context(device_id=device_id)
ms.set_context(device_id=device_id)
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
context.set_context(enable_graph_kernel=True)
ms.set_context(enable_graph_kernel=True)
# Set Parallel
if args.is_parallel:
init()
args.rank, args.rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
ms.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
else:
args.rank, args.rank_size = 0, 1
# Set Default
args.total_batch_size = args.per_batch_size * args.rank_size
args.sync_bn = args.sync_bn and context.get_context("device_target") == "Ascend" and args.rank_size > 1
args.sync_bn = args.sync_bn and ms.get_context("device_target") == "Ascend" and args.rank_size > 1
args.accumulate = max(1, np.round(args.nbs / args.total_batch_size)) if args.auto_accumulate else args.accumulate
# optimizer
args.optimizer.warmup_epochs = args.optimizer.get("warmup_epochs", 0)
Expand Down
3 changes: 3 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import mindspore as ms
from mindspore import Tensor, context, nn, ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore._c_expression import ms_ctx_param

from mindyolo.data import COCO80_TO_COCO91_CLASS, COCODataset, create_loader
from mindyolo.models.model_factory import create_model
Expand Down Expand Up @@ -71,6 +72,8 @@ def get_parser_test(parents=None):
def set_default_test(args):
# Set Context
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
if args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
Expand Down
Loading