Source code for rofunc.learning.utils.utils

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import sys
import time
from typing import Optional

import cv2
import numpy as np
import torch


def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
    """
    Seed the random number generators of Python, NumPy, PyTorch and OpenCV

    :param seed: Seed to apply. If ``None``, a random one is generated from ``os.urandom``,
                 falling back to the current time in milliseconds when that is not
                 implemented. The value is reduced modulo ``2 ** 31`` before being used
                 (default: ``None``)
    :param deterministic: If ``True``, turn off cuDNN benchmarking and turn on the cuDNN
                          deterministic flag (default: ``False``)
    :return: The seed that was actually applied, i.e. after the modulo reduction
    """
    if seed is None:
        # No seed supplied: draw 4 random bytes, or use the clock if the OS offers no source
        try:
            entropy = os.urandom(4)
        except NotImplementedError:
            seed = int(time.time() * 1000)
        else:
            seed = int.from_bytes(entropy, byteorder=sys.byteorder)

    # NumPy's legacy seeding only accepts seeds between 0 and 2**32 - 1
    seed = seed % (2 ** 31)

    # Apply the same seed to every generator, in a fixed order
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        cv2.setRNGSeed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # On CUDA 10.1, set environment variable CUDA_LAUNCH_BLOCKING=1
        # On CUDA 10.2 or later, set environment variable CUBLAS_WORKSPACE_CONFIG=:16:8
        # or CUBLAS_WORKSPACE_CONFIG=:4096:2
        # NOTE: enabling the deterministic cuDNN settings may affect performance

    return seed
def to_device(x, device):
    """
    Recursively move every tensor found in ``x`` to ``device``

    :param x: A ``torch.Tensor``, or a (possibly nested) dict / list / tuple containing tensors.
              Any other object is returned unchanged
    :param device: Target device, forwarded as-is to ``torch.Tensor.to``
    :return: ``x`` with the same nesting and all contained tensors moved to ``device``. Dicts and
             lists are rebuilt as plain ``dict`` / ``list``; namedtuples keep their type; a tuple
             in which nothing changed is returned as the very same object
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(value, device) for key, value in x.items()}
    if isinstance(x, list):
        return [to_device(item, device) for item in x]
    if isinstance(x, tuple):
        # Tuples used to fall through to the final ``return x``, which silently left
        # their tensors on the original device
        moved = [to_device(item, device) for item in x]
        if all(new is old for new, old in zip(moved, x)):
            # Nothing inside changed (e.g. a tuple of ints, or a torch.Size): keep the object
            return x
        if hasattr(x, "_fields"):
            # namedtuple constructors take their fields as positional arguments
            return type(x)(*moved)
        return tuple(moved)
    return x