# LightGTS / test_zero-shot.py
# Uploaded by pchen182224 ("Upload 9 files"), commit c882c3e (verified).
import argparse
import torch
import numpy as np
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from configuration_LightGTS import LightGTSConfig
from modeling_LightGTS import LightGTSForPrediction
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, MODEL_MAPPING
from transformers import AutoConfig
def load_window(df, column, start, length):
    """Return ``length`` consecutive values of ``df[column]`` starting at row ``start``.

    Rows are selected by position (``.iloc``), so the DataFrame's index labels
    do not affect which samples are taken.

    Args:
        df: DataFrame holding the series.
        column: Name of the column to read.
        start: Zero-based row position of the first sample.
        length: Number of samples to take.

    Returns:
        A float32 tensor of shape ``(1, length, 1)``: a batch of one
        single-channel sequence, the layout the script feeds to ``generate``.

    Raises:
        ValueError: If ``start`` or ``length`` is negative, or if ``df`` has
            fewer than ``start + length`` rows. (A plain slice would silently
            return a shorter window.)
    """
    if start < 0 or length < 0:
        raise ValueError(
            f"start and length must be non-negative, got start={start}, length={length}"
        )
    end = start + length
    if end > len(df):
        raise ValueError(
            f"column {column!r} has {len(df)} rows; need at least {end} "
            f"for start={start}, length={length}"
        )
    values = df[column].iloc[start:end].to_numpy()
    # Add a leading batch axis and a trailing channel axis; cast to float32.
    return torch.tensor(values).unsqueeze(0).unsqueeze(-1).float()


def main(argv=None):
    """Run a zero-shot LightGTS forecasting smoke test on ETTh1 and ETTh2.

    Registers the local LightGTS classes with the ``transformers`` Auto* API,
    loads the model from ``--model-dir``, and cuts a lookback window from
    ETTh1 ("HUFL") and from ETTh2 ("OT"). It forecasts ``--horizon`` steps
    for each window and prints the resulting shapes.

    Args:
        argv: Optional argument list for ``argparse``; ``None`` reads ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Zero-shot LightGTS forecasting on ETTh1/ETTh2."
    )
    parser.add_argument("--model-dir", default="./LightGTS-huggingface",
                        help="directory the config is saved to and the model is loaded from")
    parser.add_argument("--data1",
                        default="/home/wlf/LightGTS/LightGTS/data/predict_datasets/ETTh1.csv",
                        help="CSV with a 'HUFL' column")
    parser.add_argument("--data2",
                        default="/home/wlf/LightGTS/LightGTS/data/predict_datasets/ETTh2.csv",
                        help="CSV with an 'OT' column")
    parser.add_argument("--start", type=int, default=300,
                        help="row position where the lookback window begins")
    parser.add_argument("--lookback", type=int, default=576,
                        help="number of history steps fed to the model")
    parser.add_argument("--horizon", type=int, default=192,
                        help="number of steps to forecast (max_output_length)")
    args = parser.parse_args(argv)

    # NOTE(review): context_points=528 differs from the default 576-step
    # lookback passed to generate() below, and target_dim=192 only matches the
    # default --horizon; confirm what the checkpoint actually expects.
    config = LightGTSConfig(context_points=528, c_in=1, target_dim=192,
                            patch_len=48, stride=48)
    # Save the config into the same directory the model is loaded from below.
    config.save_pretrained(args.model_dir)

    AutoConfig.register("LightGTS", LightGTSConfig)
    AutoModelForCausalLM.register(LightGTSConfig, LightGTSForPrediction)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        trust_remote_code=True,
    )

    df1 = pd.read_csv(args.data1)
    df2 = pd.read_csv(args.data2)
    print(df1, df2)

    lookback = load_window(df1, "HUFL", args.start, args.lookback)
    lookback2 = load_window(df2, "OT", args.start, args.lookback)
    print(lookback.shape)

    # Zero-shot inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model.generate(lookback, patch_len=48, stride_len=48,
                                 max_output_length=args.horizon)
        outputs2 = model.generate(lookback2, patch_len=32, stride_len=32,
                                  max_output_length=args.horizon)
    print(outputs.shape)
    print(outputs2.shape)


if __name__ == "__main__":
    main()