| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import gc |
| import time |
|
|
| import ray |
|
|
| from verl.single_controller.base.worker import Worker |
| from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool |
|
|
|
|
@ray.remote
class TestActor(Worker):
    """Bare-bones ``Worker`` actor used to probe where Ray schedules worker groups."""

    def __init__(self, cuda_visible_devices=None) -> None:
        # Nothing actor-specific to set up; hand the optional device list to Worker.
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        """Return the id of the Ray node this actor is running on."""
        runtime_ctx = ray.get_runtime_context()
        return runtime_ctx.get_node_id()
|
|
|
|
def _assert_visible_devices(label, worker_group, expected_ids):
    """Assert that ``worker_group`` reports exactly ``expected_ids`` as its visible GPUs.

    Args:
        label: Short name of the worker group; used only in the failure message.
        worker_group: A ``RayWorkerGroup`` whose workers answer
            ``get_cuda_visible_devices``.
        expected_ids: Iterable of integer GPU ids. They are converted with ``str()``
            and compared, in order, against the result of ``execute_all_sync``.
    """
    expected = [str(i) for i in expected_ids]
    actual = worker_group.execute_all_sync("get_cuda_visible_devices")
    assert actual == expected, f"{label}: expected visible devices {expected}, got {actual}"


def _check_single_node_no_partition(class_with_args, num_gpus):
    """Colocate four worker groups on one shared pool and check each spans every GPU.

    Afterwards the worker groups are dropped and the pool's placement groups are
    removed, so that the next scenario can reserve the same GPUs again.

    Args:
        class_with_args: ``RayClassWithInitArgs`` used to spawn every worker.
        num_gpus: Number of GPU slots requested for the shared pool.
    """
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([num_gpus], use_gpu=True)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    # Prefix was previously misspelled "hight_level_api_critic".
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")

    # Every colocated group must see all GPUs of the shared pool.
    all_gpus = range(num_gpus)
    _assert_visible_devices("actor", actor_wg, all_gpus)
    _assert_visible_devices("critic", critic_wg, all_gpus)
    _assert_visible_devices("rm", rm_wg, all_gpus)
    _assert_visible_devices("ref", ref_wg, all_gpus)

    # Drop every handle to the worker groups and force a collection. Presumably
    # this lets their actors be torn down before the placement groups go away.
    del actor_wg, critic_wg, rm_wg, ref_wg
    gc.collect()

    for pg in resource_pool.get_placement_groups():
        ray.util.remove_placement_group(pg)
    print("wait 5s to remove placement_group")
    # Give Ray time to actually release the reservations before they are reused.
    time.sleep(5)


def _check_single_node_multi_partition(class_with_args, num_gpus):
    """Split the GPUs into two pools and check groups on each pool and on their merge.

    The ``rm`` pool is expected to own the first half of the GPU ids and the
    ``ref`` pool the second half. Groups on the merged pool must see all of them.

    Args:
        class_with_args: ``RayClassWithInitArgs`` used to spawn every worker.
        num_gpus: Total GPUs across both partitions; should be even.
    """
    print("test single-node-multi-partition")
    half = num_gpus // 2
    rm_resource_pool = RayResourcePool([half], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([half], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == half
    assert ref_resource_pool.world_size == half
    assert total_resource_pool.world_size == num_gpus

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")

    _assert_visible_devices("actor", actor_wg, range(num_gpus))
    _assert_visible_devices("critic", critic_wg, range(num_gpus))
    _assert_visible_devices("rm", rm_wg, range(half))
    _assert_visible_devices("ref", ref_wg, range(half, num_gpus))


def test():
    """End-to-end check of GPU placement for ``RayWorkerGroup`` on shared and split pools.

    This test needs a Ray cluster that can place 8 GPU bundles, presumably a single
    8-GPU node, as the scenario names suggest. ``ray.shutdown()`` runs in a
    ``finally`` block so that a failed assertion does not leave Ray initialised
    for later tests in the same process.
    """
    num_gpus = 8
    ray.init()
    try:
        class_with_args = RayClassWithInitArgs(cls=TestActor)
        _check_single_node_no_partition(class_with_args, num_gpus)
        _check_single_node_multi_partition(class_with_args, num_gpus)
    finally:
        ray.shutdown()
|
|