| import sys, os, json |
# Locate the repository root: the first ancestor directory of this file named
# "Recurrent-Parameter-Generation". It is appended to sys.path so the
# `workspace.*` import below resolves, and made the working directory so the
# relative "./checkpoint/..." paths used later resolve.
# `__file__` is not guaranteed to be absolute, so normalize it first; keeping the
# first split component (empty on POSIX, the drive on Windows) rebuilds a valid
# absolute path on both platforms.
_this_file_parts = os.path.abspath(__file__).split(os.sep)
root = os.sep.join(_this_file_parts[:_this_file_parts.index("Recurrent-Parameter-Generation") + 1])
sys.path.append(root)
os.chdir(root)
|
|
| |
| import torch |
| from torch import nn |
| |
| from workspace.condition import generalization as item |
# Pull the datasets, config and model out of the experiment module.
train_set = item.train_set
test_set = item.test_set
# NOTE(review): semantics of set_infinite_dataset live in the dataset class;
# here it is called with the dataset's own real_length as the cap.
test_set.set_infinite_dataset(max_num=test_set.real_length)
print("num_generated:", test_set.real_length)
config = item.config
model = item.model
# The tag names the checkpoint file and the generated outputs. Validate it
# explicitly: an `assert` would be stripped under `python -O`, letting a
# missing tag silently become "./checkpoint/None.pth".
if config.get("tag") is None:
    raise ValueError("Remember to set a tag.")
|
|
|
|
|
|
|
|
# Settings for this generation run. Both templates keep the directory part
# (everything before the last "/") of the test set's own generated_path and
# test_command, and replace the file name with "generated_{}_{}.pth"; the two
# "{}" slots are filled later with the config tag and "class<index>".
generate_config = dict(
    device="cuda",
    checkpoint=f"./checkpoint/{config['tag']}.pth",
    generated_path=os.path.join(test_set.generated_path.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    test_command=os.path.join(test_set.test_command.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    need_test=True,  # run test_command after each generation
    specific_item=None,  # an int index here restricts generation to that test item
)
# Merge into the experiment config; these keys override any existing ones.
config.update(generate_config)
|
|
|
|
|
|
|
|
| |
print('==> Building model..')
# Load the checkpoint onto the CPU. load_state_dict copies the values into the
# model's own parameters, and the model is moved to config["device"] below, so
# the checkpoint tensors never need to live on the target device. This avoids a
# temporary extra device copy, and a failure when the checkpoint was saved on a
# device index that is not available on this machine.
diction = torch.load(config["checkpoint"], map_location="cpu")
# Rebuild the to_permutation_state embedding with the
# (num_embeddings, embedding_dim) shape stored in the checkpoint so that
# load_state_dict does not hit a size mismatch on that weight.
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])
|
|
|
|
| |
print('==> Defining generate..')
def generate(save_path=config["generated_path"], test_command=config["test_command"], need_test=True, index=None):
    """Generate parameters for one test-set condition, save them, and optionally test them.

    Args:
        save_path: Format string with two ``{}`` slots, filled with
            ``config["tag"]`` and ``"class<index>"`` to produce the output path
            passed to ``train_set.save_params``.
        test_command: Shell command format string with the same two slots; it is
            run through ``os.system`` when ``need_test`` is true.
        need_test: Whether to run ``test_command`` after saving.
        index: Index passed to ``test_set[...]``; the second element of that item
            is used as the condition. The default ``None`` is passed through as is.

    Returns:
        The model's prediction for the selected condition.
    """
    print("\n==> Generating..")
    model.eval()
    try:
        _, condition = test_set[index]
        # Read the condition entries as binary digits (each cast to int); the
        # resulting number, zero-padded to 4 digits, names the output files.
        class_index = str(int("".join(str(int(bit)) for bit in condition), 2)).zfill(4)
        with torch.no_grad():
            # `condition[None]` prepends a batch axis (assuming condition is a tensor).
            prediction = model(sample=True, condition=condition[None], permutation_state=False)
        generated_norm = torch.nanmean(prediction.cpu().abs())
        print("Generated_norm:", generated_norm.item())
        class_tag = f"class{class_index}"
        train_set.save_params(prediction, save_path=save_path.format(config["tag"], class_tag))
        if need_test:
            exit_status = os.system(test_command.format(config["tag"], class_tag))
            if exit_status != 0:
                # Best-effort test run: report failures instead of silently ignoring them.
                print(f"Warning: test command exited with status {exit_status}")
    finally:
        # Always restore training mode, even if generation, saving or testing fails.
        model.train()
    return prediction
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    specific_item = config["specific_item"]
    if specific_item is not None:
        # Generate only the requested test item, exactly once.
        if not isinstance(specific_item, int):
            raise TypeError(f"specific_item must be an int or None, got {type(specific_item).__name__}")
        indices = [specific_item]
    else:
        indices = range(len(test_set))
    # The save-path template does not depend on the item, so report it once.
    print("Save to", config["generated_path"].format(config["tag"], "classXXX"))
    for i in indices:
        generate(
            save_path=config["generated_path"],
            test_command=config["test_command"],
            need_test=config["need_test"],
            index=i,
        )