Fix IPEX support and add XPU device to device_utils
Disty0 committed Jan 31, 2024
1 parent 2ca4d0c commit a6a2b5a
Showing 27 changed files with 248 additions and 245 deletions.
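Across every script touched by this commit, the change follows the same pattern: the `library.ipex_interop` import is dropped and `init_ipex` is pulled from `library.device_utils` together with whichever other helpers the script uses, with `init_ipex()` called immediately after `import torch`. A minimal sketch of the resulting script header, assuming a script that also needs a device handle (the exact import list varies per file, as the diffs below show):

import torch

# Consolidated device helpers; init_ipex() should run right after importing torch
# and before any other torch usage.
from library.device_utils import init_ipex, get_preferred_device

init_ipex()

DEVICE = get_preferred_device()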
4 changes: 2 additions & 2 deletions XTI_hijack.py
@@ -1,7 +1,7 @@
 import torch
-from library.ipex_interop import init_ipex
 
+from library.device_utils import init_ipex
 init_ipex()
 
 from typing import Union, List, Optional, Dict, Any, Tuple
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput

6 changes: 2 additions & 4 deletions fine_tune.py
@@ -8,11 +8,9 @@
 import toml
 
 from tqdm import tqdm
-import torch
-
-from library.device_utils import clean_memory
-from library.ipex_interop import init_ipex
 
+import torch
+from library.device_utils import init_ipex, clean_memory
 init_ipex()
 
 from accelerate.utils import set_seed
5 changes: 4 additions & 1 deletion finetune/make_captions.py
@@ -9,13 +9,16 @@
 from PIL import Image
 from tqdm import tqdm
 import numpy as np
 
+import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
+
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
-from library.device_utils import get_preferred_device
 
 DEVICE = get_preferred_device()

5 changes: 4 additions & 1 deletion finetune/make_captions_by_git.py
@@ -5,12 +5,15 @@
 from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 
+import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
+
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers.generation.utils import GenerationMixin
 
 import library.train_util as train_util
-from library.device_utils import get_preferred_device
 
 DEVICE = get_preferred_device()

6 changes: 4 additions & 2 deletions finetune/prepare_buckets_latents.py
@@ -8,14 +8,16 @@
 import numpy as np
 from PIL import Image
 import cv2
 
+import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
+
 from torchvision import transforms
 
 import library.model_util as model_util
 import library.train_util as train_util
 
-from library.device_utils import get_preferred_device
-
 DEVICE = get_preferred_device()
 
 IMAGE_TRANSFORMS = transforms.Compose(
6 changes: 2 additions & 4 deletions gen_img_diffusers.py
@@ -64,11 +64,9 @@
 
 import diffusers
 import numpy as np
-import torch
-
-from library.device_utils import clean_memory, get_preferred_device
-from library.ipex_interop import init_ipex
 
+import torch
+from library.device_utils import init_ipex, clean_memory, get_preferred_device
 init_ipex()
 
 import torchvision
29 changes: 29 additions & 0 deletions library/device_utils.py
@@ -13,11 +13,19 @@
 except Exception:
     HAS_MPS = False
 
+try:
+    import intel_extension_for_pytorch as ipex  # noqa
+    HAS_XPU = torch.xpu.is_available()
+except Exception:
+    HAS_XPU = False
+
 
 def clean_memory():
     gc.collect()
     if HAS_CUDA:
         torch.cuda.empty_cache()
+    if HAS_XPU:
+        torch.xpu.empty_cache()
     if HAS_MPS:
         torch.mps.empty_cache()

@@ -26,9 +34,30 @@ def clean_memory():
 def get_preferred_device() -> torch.device:
     if HAS_CUDA:
         device = torch.device("cuda")
+    elif HAS_XPU:
+        device = torch.device("xpu")
     elif HAS_MPS:
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
     print(f"get_preferred_device() -> {device}")
     return device
+
+
+def init_ipex():
+    """
+    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
+    This function should run right after importing torch and before doing anything else.
+    If IPEX is not available, this function does nothing.
+    """
+    try:
+        if HAS_XPU:
+            from library.ipex import ipex_init
+            is_initialized, error_message = ipex_init()
+            if not is_initialized:
+                print("failed to initialize ipex:", error_message)
+        else:
+            return
+    except Exception as e:
+        print("failed to initialize ipex:", e)
