import os


class GPUManager:
    """Inspect NVIDIA GPUs through the ``nvidia-smi`` command-line tool.

    Every GPU is described by a dict mapping each name in ``queries`` to the
    raw text ``nvidia-smi`` printed for that field.  When the
    ``CUDA_VISIBLE_DEVICES`` environment variable is set, only the listed
    devices are reported, and their ``"index"`` is rewritten to their
    position in that list.
    """

    # Field names passed to ``nvidia-smi --query-gpu``; also the keys of the
    # per-GPU dicts and the column order of the text report.
    queries = (
        "index",
        "gpu_name",
        "memory.free",
        "memory.used",
        "memory.total",
        "power.draw",
        "power.limit",
    )

    def __init__(self):
        """Run one query at construction time.

        The result is discarded; the only observable effect is that an
        ``AssertionError`` from ``query_gpu`` (a ``CUDA_VISIBLE_DEVICES``
        entry matching no GPU) is raised here.
        """
        self.query_gpu(False)

    def get_info(self, ctype):
        """Return the ``nvidia-smi`` output lines for one query field.

        Args:
            ctype: Field name for ``--query-gpu`` (e.g. ``"memory.free"``).
                It is interpolated into a shell command without escaping,
                so pass trusted field names only.

        Returns:
            A list with one string per output line, trailing newlines removed.
        """
        cmd = "nvidia-smi --query-gpu={} --format=csv,noheader".format(ctype)
        # The context manager closes the pipe (and waits for the child)
        # instead of leaving that to garbage collection.
        with os.popen(cmd) as pipe:
            return [line.strip("\n") for line in pipe]

    def query_gpu(self, show=True):
        """Collect every field in ``queries`` for each GPU.

        Args:
            show: If true, return a formatted multi-line string; otherwise
                return the list of per-GPU dicts.

        Returns:
            ``str`` when ``show`` is true, else ``list[dict[str, str]]``.

        Raises:
            AssertionError: If a ``CUDA_VISIBLE_DEVICES`` entry matches no
                GPU index, or matches more than one.
        """
        num_gpus = len(self.get_info("index"))
        all_gpus = [{} for _ in range(num_gpus)]
        for query in self.queries:
            for idx, info in enumerate(self.get_info(query)):
                all_gpus[idx][query] = info

        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible is not None:
            all_gpus = self._select_visible(all_gpus, visible)

        if show:
            return self._format_gpus(all_gpus)
        return all_gpus

    @staticmethod
    def _select_visible(all_gpus, visible):
        """Keep the GPUs named in a ``CUDA_VISIBLE_DEVICES`` value, renumbered by position."""
        # A blank value hides every device, so nothing is selected.
        if visible.strip():
            # Tolerate spaces around the commas, e.g. "0, 1".
            devices = [entry.strip() for entry in visible.split(",")]
        else:
            devices = []

        selected = []
        for position, device in enumerate(devices):
            matches = [gpu for gpu in all_gpus if gpu["index"] == device]
            # Raised explicitly (not via ``assert``) so the checks survive
            # ``python -O``; the exception type is unchanged for callers.
            if len(matches) > 1:
                raise AssertionError(
                    "Duplicate cuda device index : {}".format(device)
                )
            if not matches:
                raise AssertionError(
                    "Does not find the device : {}".format(device)
                )
            renumbered = matches[0].copy()
            renumbered["index"] = "{}".format(position)
            selected.append(renumbered)
        return selected

    def _format_gpus(self, gpus):
        """Render one ``| field : value | ... |`` line per GPU."""
        rows = []
        for gpu in gpus:
            cells = []
            for query in self.queries:
                value = gpu[query]
                if query.startswith("memory"):
                    # Right-align memory figures so the columns line up.
                    value = "{:>9}".format(value)
                cells.append(query + " : " + value + " | ")
            rows.append("| " + "".join(cells) + "\n")
        return "".join(rows)

    def select_by_memory(self, numbers=1):
        """Return the indices of the ``numbers`` GPUs with the most free memory.

        Args:
            numbers: How many GPUs to pick.

        Returns:
            The chosen GPU indices as ints, in ascending order.

        Raises:
            AssertionError: If ``numbers`` exceeds the number of GPUs found.
        """
        all_gpus = self.query_gpu(False)
        if numbers > len(all_gpus):
            raise AssertionError(
                "Require {} gpus more than you have".format(numbers)
            )
        # The text before the first space of "memory.free" is parsed as an
        # integer amount; the unit suffix is ignored.
        ranked = sorted(
            (
                (int(gpu["memory.free"].split(" ")[0]), gpu["index"])
                for gpu in all_gpus
            ),
            reverse=True,
        )
        # max(..., 0) keeps a non-positive ``numbers`` yielding an empty list.
        chosen = ranked[: max(numbers, 0)]
        return sorted(int(index) for _, index in chosen)


"""
if __name__ == '__main__':
  manager = GPUManager()
  manager.query_gpu(True)
  indexes = manager.select_by_memory(3)
  print (indexes)
"""