# Source code for rofunc.learning.utils.utils
# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
import time
from typing import Optional
import cv2
import numpy as np
import torch
def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
    """
    Set the seed for the ``random``, NumPy, PyTorch (CPU and all CUDA devices) and OpenCV generators

    :param seed: The seed to set. If None, a random seed will be generated (default: ``None``).
                 Any integer (Python or NumPy) is accepted; it is reduced modulo ``2**31`` so it lies
                 in ``[0, 2**31 - 1]``, a range accepted by every generator seeded here. Seeds already
                 in that range are used unchanged.
    :param deterministic: If ``True``, disable ``torch.backends.cudnn.benchmark`` and enable
                          ``torch.backends.cudnn.deterministic`` (default: ``False``). No other
                          determinism settings are changed.
    :return: The seed that was actually applied (the generated or reduced value)
    :raises TypeError: If ``seed`` is not an integer (e.g. a float)
    """
    import operator

    if seed is None:
        # Generate a random seed; fall back to the clock if the OS has no randomness source.
        try:
            seed = int.from_bytes(os.urandom(4), byteorder=sys.byteorder)
        except NotImplementedError:
            seed = int(time.time() * 1000)

    # Validate and normalise BEFORE touching any generator so a bad seed cannot leave the
    # generators half-seeded. operator.index accepts Python/NumPy integers, rejects floats and
    # yields a plain int. NumPy's legacy seeding needs 0 <= seed < 2**32 and cv2.setRNGSeed
    # takes a C ``int``, hence the reduction into [0, 2**31 - 1].
    seed = operator.index(seed) % 2 ** 31

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.manual_seed already seeds CUDA as well; kept explicit for all devices.
    torch.cuda.manual_seed_all(seed)
    cv2.setRNGSeed(seed)

    if deterministic:
        # Trades cuDNN autotuning speed for reproducible convolution algorithms.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # NOTE: this alone does not make every CUDA op deterministic and may affect performance.
        # On CUDA 10.1, set environment variable CUDA_LAUNCH_BLOCKING=1
        # On CUDA 10.2 or later, set environment variable CUBLAS_WORKSPACE_CONFIG=:16:8 or CUBLAS_WORKSPACE_CONFIG=:4096:2
    return seed
def to_device(x, device):
    """
    Recursively move every tensor in a (possibly nested) structure to ``device``

    :param x: A tensor, or a dict / list / tuple (including namedtuples) whose leaves may be
              tensors. Any other object is returned unchanged.
    :param device: Forwarded as-is to :meth:`torch.Tensor.to` (e.g. ``"cpu"``, ``"cuda:0"``,
                   a ``torch.device``).
    :return: A structure mirroring ``x`` in which every tensor leaf ``t`` is replaced by
             ``t.to(device)``. Dicts are rebuilt as plain ``dict`` and lists as ``list``. Tuples
             keep their type, and a tuple whose elements are all returned unchanged (e.g. a
             ``torch.Size``) is returned as the very same object.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [to_device(v, device) for v in x]
    if isinstance(x, tuple):
        moved = [to_device(v, device) for v in x]
        # Nothing changed -> hand back the original (keeps exotic tuple subclasses intact).
        if all(new is old for new, old in zip(moved, x)):
            return x
        # namedtuples take their fields positionally; plain tuples and sequence-constructed
        # subclasses (torch.Size, struct sequences) take a single iterable.
        if hasattr(x, "_fields"):
            return type(x)(*moved)
        return type(x)(moved)
    return x