| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
|
|
| import ray |
| import torch |
| from tensordict import TensorDict |
|
|
| from verl import DataProto |
| from verl.single_controller.base.worker import Worker |
| from verl.single_controller.ray import RayWorkerGroup |
| from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool |
|
|
# Process-wide logging knobs, set at import time so they are already in the
# environment when test() calls ray.init().
# NOTE(review): RAY_DEDUP_LOGS=0 presumably switches off Ray's log
# de-duplication and NCCL_DEBUG=WARN presumably limits NCCL output to
# warnings -- confirm against the Ray / NCCL documentation.
os.environ.update(
    {
        "RAY_DEDUP_LOGS": "0",
        "NCCL_DEBUG": "WARN",
    }
)
|
|
|
|
@ray.remote
class ModelActor(Worker):
    """Ray remote class derived from ``Worker`` that adds no behaviour of its own.

    ``test`` wraps this class in ``RayClassWithInitArgs`` and hands it to a
    ``RayWorkerGroup``.
    """

    def __init__(self):
        """Do nothing: no attributes are set and ``Worker.__init__`` is not called."""
|
|
|
|
class HackSelf:
    """Empty placeholder object.

    ``test`` passes an instance of this class as the ``self`` argument when it
    calls ``get_aux_metrics`` directly in the local process.
    """

    def __init__(self):
        """Create the placeholder; no attributes are set."""
|
|
|
|
def get_aux_metrics(self, test_proto):
    """Build a ``DataProto`` pairing every sequence with its element count.

    Args:
        self: Not used by the body; the function only has a method-like
            signature.
        test_proto: Object whose ``batch["sequence_ids"]`` is a tensor that
            can be indexed along dimension 0.

    Returns:
        A new ``DataProto`` whose batch contains the input ``sequence_ids``
        unchanged and a ``decode_count`` tensor holding
        ``len(row.tolist())`` for each row of ``sequence_ids``.
    """
    seq_ids = test_proto.batch["sequence_ids"]
    num_rows = seq_ids.size(0)
    # One entry per row: the number of elements in that row.
    row_lengths = [len(seq_ids[row].tolist()) for row in range(num_rows)]
    payload = {"sequence_ids": seq_ids, "decode_count": torch.tensor(row_lengths)}
    return DataProto(batch=TensorDict(payload, batch_size=num_rows))
|
|
|
|
def test():
    """Check that ``get_aux_metrics`` yields the same ``decode_count`` remotely and locally.

    The function is run once through
    ``RayWorkerGroup.execute_with_func_generator`` and once directly in this
    process with a ``HackSelf`` placeholder as ``self``; the two
    ``decode_count`` tensors are then compared with
    ``torch.testing.assert_close``.

    Raises:
        AssertionError: If the two ``decode_count`` tensors differ.
    """
    ray.init()
    # Everything after ray.init() lives in try/finally so the Ray runtime is
    # torn down even when setup or the assertion fails; otherwise a failure
    # would leave Ray initialised for whatever runs next in this process.
    try:
        # NOTE(review): presumably a pool of 2 GPU-backed worker processes on
        # a single node -- confirm against RayResourcePool.
        resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")

        class_with_args = RayClassWithInitArgs(cls=ModelActor)
        shard_wg = RayWorkerGroup(resource_pool, class_with_args)

        test_bs = 8
        test_proto = DataProto(
            TensorDict(
                {
                    "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
                },
                batch_size=test_bs,
            ),
            # get_aux_metrics only reads batch["sequence_ids"]; this entry is
            # carried along but not consulted by it.
            meta_info={"query_length": 1536},
        )

        # Run through the worker group.
        # NOTE(review): how the batch is split across workers and merged back
        # is decided inside execute_with_func_generator -- confirm there.
        ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

        # Reference result: call the same function directly in this process.
        hs = HackSelf()
        ret_proto2 = get_aux_metrics(hs, test_proto)

        torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])
    finally:
        ray.shutdown()
|
|