How to give them CUDA

Set up num_gpus=1

import ray
import torch
import os
import torch
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Device Count:", torch.cuda.device_count())
print("Current CUDA Device:", torch.cuda.current_device())
print("CUDA Device Name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

ray.init(num_gpus=1)

@ray.remote(num_gpus=1)
def check_cuda():
    print(os.environ.get("CUDA_VISIBLE_DEVICES"))
    return torch.cuda.is_available(), torch.cuda.device_count()

print(ray.get(check_cuda.remote()))
    @staticmethod
    @ray.remote(num_gpus=1)
    def match_in_parallel(descriptor, mutual_matching, inlier_search, num_keypts, 
                          des_r, src_pts, tgt_pts, src_pcd_raw, tgt_pcd_raw, dataset_name, cfg):
        # Keypoint Sampling
        s_pts_flipped, t_pts_flipped = src_pts[None].transpose(1, 2).contiguous(), tgt_pts[None].transpose(1,2).contiguous()
        s_fps_idx = pnt2.furthest_point_sample(src_pts[None], num_keypts)
        t_fps_idx = pnt2.furthest_point_sample(tgt_pts[None], num_keypts)
        kpts1 = pnt2.gather_operation(s_pts_flipped, s_fps_idx).transpose(1, 2).contiguous()
        kpts2 = pnt2.gather_operation(t_pts_flipped, t_fps_idx).transpose(1, 2).contiguous()

        # Compute descriptors
        src = descriptor(src_pcd_raw[None], kpts1, des_r, dataset_name)
        tgt = descriptor(tgt_pcd_raw[None], kpts2, des_r, dataset_name)
        src_des, src_equi, s_R = src['desc'], src['equi'], src['R']
        tgt_des, tgt_equi, t_R = tgt['desc'], tgt['equi'], tgt['R']

        # Mutual Matching
        s_mids, t_mids = mutual_matching(src_des, tgt_des)

        ss_kpts = kpts1[0, s_mids]
        ss_equi = src_equi[s_mids]
        ss_R = s_R[s_mids]
        tt_kpts = kpts2[0, t_mids]
        tt_equi = tgt_equi[t_mids]
        tt_R = t_R[t_mids]


        ind = inlier_search(ss_equi[:, :, 1:cfg.patch.ele_n - 1],
                            tt_equi[:, :, 1:cfg.patch.ele_n - 1])

        # Recovery pose
        angle = ind * 2 * np.pi / cfg.patch.azi_n + 1e-6
        angle_axis = torch.zeros_like(ss_kpts)
        angle_axis[:, -1] = 1
        angle_axis = angle_axis * angle[:, None]
        azi_R = Convert.axis_angle_to_rotation_matrix(angle_axis)

        R = tt_R @ azi_R @ ss_R.transpose(-1, -2)
        t = tt_kpts - (R @ ss_kpts.unsqueeze(-1)).squeeze()

        return R, t, ss_kpts, tt_kpts


하지만 직렬화 오버헤드가 문제가 될 수 있다.

직렬화 오버헤드란?

직렬화(Serialization)는 객체를 저장하거나 네트워크로 전송하기 위해 바이트(byte) 형태로 변환하는 과정이다. 반대로 역직렬화(Deserialization)는 바이트 형태 데이터를 다시 원래 객체로 변환하는 과정이다.

직렬화 오버헤드(Serialization Overhead)

데이터를 직렬화하는 데 걸리는 추가적인 연산 비용을 의미한다. Ray에서 remote()를 호출할 때, 객체를 pickle로 변환해서 전달해야 한다. 하지만 객체가 크거나, 너무 자주 직렬화하면 속도가 느려질 수 있다.

직렬화 오버헤드가 발생하는 이유

futures = [MyNetwork.process_data.remote(x, self.calculator) for x in data_list]

self.calculator는 일반적인 Python 객체이므로, Ray는 이를 다른 프로세스로 넘기기 위해 pickle을 이용해 직렬화한다. 100개의 Worker가 생성될 때마다 self.calculator를 계속 직렬화/역직렬화하기 때문에 불필요한 오버헤드가 발생한다.

실제로 네트워크를 pickle로 전달하는 과정에서 오히려 시간이 더 걸리는 경우가 많다.