#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
from tqdm import tqdm
from xautodl.procedures import prepare_seed, prepare_logger
from xautodl.datasets.synthetic_core import get_synthetic_env


def train_model(model, dataset, lr, epochs):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
    best_loss, best_param = None, None
    for _iepoch in range(epochs):
        preds = model(dataset.x)
        optimizer.zero_grad()
        loss = criterion(preds, dataset.y)
        loss.backward()
        optimizer.step()
        # save best
        if best_loss is None or best_loss > loss.item():
            best_loss = loss.item()
            best_param = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_param)
    return best_loss


class TimeData:
    def __init__(self, timestamp, xs, ys):
        self._timestamp = timestamp
        self._xs = xs
        self._ys = ys

    @property
    def x(self):
        return self._xs

    @property
    def y(self):
        return self._ys

    @property
    def timestamp(self):
        return self._timestamp

    def __repr__(self):
        return "{name}(timestamp={timestamp}, with {num} samples)".format(
            name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs)
        )