| import numpy as np
|
| import random
|
| import torch
|
| import torch.nn as nn
|
| import os
|
| import inspect
|
| import pickle
|
| import gdown
|
| from network import Actor
|
|
|
|
|
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py

    Intended to be used as ``module.apply(weight_init)``; modules of any other
    type are left untouched.

    * ``nn.Linear``: orthogonal weight (gain 1), zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` (square kernels only): every
      weight is zeroed except the centre tap of each kernel, which is set to an
      orthogonal matrix over the two channel dimensions with ReLU gain; zero bias.

    Layers created with ``bias=False`` are supported (their bias is ``None``).

    Args:
        m: the ``nn.Module`` to initialise in place.

    Raises:
        ValueError: if a convolution kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # bias is None for nn.Linear(..., bias=False)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # weight layout indexed below is (dim0, dim1, kH, kW); the centre-tap
        # scheme only makes sense for square kernels.
        if m.weight.size(2) != m.weight.size(3):
            raise ValueError(
                f"weight_init expects square kernels, got "
                f"{m.weight.size(2)}x{m.weight.size(3)}"
            )
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        # Only the centre spatial position receives non-zero (orthogonal) values.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
|
|
|
|
|
def set_seed(random_seed):
    """Seed torch, numpy and Python's ``random`` module with one value.

    Args:
        random_seed: the seed to apply. A value ``<= 0`` (or ``None``) means
            "pick one for me": a seed in ``[1, 9998]`` is drawn from numpy's
            global RNG before seeding.

    Returns:
        int: the seed that was actually applied, so the caller can log it.
    """
    if random_seed is None or random_seed <= 0:
        # np.random.randint's upper bound is exclusive -> seed in [1, 9998]
        random_seed = np.random.randint(1, 9999)
    # Normalise numpy integer types to a plain int for callers/loggers.
    random_seed = int(random_seed)

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    return random_seed
|
|
|
|
|
def make_env(env_name, seed):
    """Create a gymnasium environment and describe its spaces.

    Only the action space's RNG is seeded here; the environment itself is not
    reset or seeded by this function.

    Args:
        env_name: id passed to ``gymnasium.make``.
        seed: seed for the action space (also echoed back in the info dict).

    Returns:
        tuple: ``(env, env_info)``. ``env_info`` holds the env name, the first
        dimension of the observation and action space shapes, the
        ``[low, high]`` bounds of the first action component, and the seed.
    """
    import gymnasium as gym

    environment = gym.make(env_name)
    act_space = environment.action_space
    act_space.seed(seed)

    env_info = dict(
        name=env_name,
        state_dim=environment.observation_space.shape[0],
        action_dim=act_space.shape[0],
        # Only the first component's bounds are recorded -- presumably every
        # action dimension shares them; NOTE(review): confirm per env.
        action_bound=[act_space.low[0], act_space.high[0]],
        seed=seed,
    )
    return environment, env_info
|
|
|
|
|
def get_learning_info(args, seed):
    """Create the environment and collect the keyword arguments for the agent.

    Args:
        args: parsed run configuration; this function reads ``env_name``,
            ``alpha_threshold``, ``theta_threshold``, ``discount``,
            ``datasize``, ``tau``, ``noise_clip``, ``policy_freq`` and ``h``.
        seed: seed forwarded to ``make_env``.

    Returns:
        dict: keyword arguments including the env, its info dict, the
        threshold dict, the torch device string (``'cuda'`` when available,
        otherwise ``'cpu'``) and the parameter count of a reference teacher
        ``Actor``.
    """
    env, env_info = make_env(args.env_name, seed)
    # Fall back to CPU instead of failing later on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-environment overrides for the alpha threshold. Every listed env
    # currently uses the CLI value; unlisted envs fall back to it as well
    # instead of raising KeyError.
    alpha_dict = {'HalfCheetah-v3': args.alpha_threshold, 'Walker2d-v3': args.alpha_threshold,
                  'Ant-v3': args.alpha_threshold, 'Hopper-v3': args.alpha_threshold}

    thresholds = {
        "ALPHA_THRESHOLD": alpha_dict.get(args.env_name, args.alpha_threshold),
        "THETA_THRESHOLD": args.theta_threshold,
    }
    # Scales noise_clip below. NOTE(review): env_info['action_bound'] is not
    # consulted here -- assumes actions are normalised to magnitude 1; confirm.
    max_action = 1

    # A teacher-sized Actor is built only to count its parameters (used later
    # as the denominator of the compression ratio). The (400, 300) / 1
    # arguments presumably are hidden sizes / max action -- confirm in network.Actor.
    t_p = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
    num_teacher_param = sum(p.numel() for p in t_p.parameters())

    kwargs = {
        "env": env,
        "args": args,
        "env_info": env_info,
        "thresholds": thresholds,
        "discount": args.discount,
        "datasize": args.datasize,
        "tau": args.tau,
        "device": device,
        "num_teacher_param": num_teacher_param,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "h": args.h,
    }
    return kwargs
|
|
|
|
|
def get_compression_ratio(num_teacher_param, agent):
    """Return the agent actor's kept-weight count relative to the teacher.

    Adds up ``get_num_remained_weights()`` over the direct children of
    ``agent.actor`` and divides the total by ``num_teacher_param``.

    Args:
        num_teacher_param: parameter count of the reference teacher network.
        agent: object exposing an ``actor`` module whose children implement
            ``get_num_remained_weights()``.

    Returns:
        The ratio ``kept_weights / num_teacher_param``.
    """
    remained = sum(layer.get_num_remained_weights() for layer in agent.actor.children())
    return remained / num_teacher_param
|
|
|
|
|
# Google Drive file ids of the pre-collected teacher buffers, keyed by buffer
# level and then by environment name.
_TEACHER_BUFFER_FILE_IDS = {
    'expert': {
        "Ant-v3": "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO",
        "Walker2d-v3": "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D",
        "HalfCheetah-v3": "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl",
        "Hopper-v3": "10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE",
    },
    'medium': {
        "Ant-v3": "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D",
        "Walker2d-v3": "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh",
        "HalfCheetah-v3": "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR",
        "Hopper-v3": "1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-",
    },
}


def _read_buffer(file_path, datasize):
    """Unpickle the buffer stored at ``file_path`` and set its ``size`` to ``datasize``."""
    # SECURITY NOTE: pickle.load can execute arbitrary code. This is only
    # acceptable because the file is either local or fetched from the fixed
    # Drive ids above; never point this at untrusted files.
    with open(file_path, "rb") as fr:
        buffer = pickle.load(fr)
    buffer.size = datasize
    return buffer


def _download_buffer(env_name, level, file_path):
    """Download the teacher buffer for ``(env_name, level)`` from Google Drive to ``file_path``."""
    if level not in _TEACHER_BUFFER_FILE_IDS:
        raise ValueError("Invalid Level. Choose from ['expert', 'medium']")
    file_ids = _TEACHER_BUFFER_FILE_IDS[level]
    if env_name not in file_ids:
        raise ValueError("Invalid Environment Name")

    # The teacher_buffer/ directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    print("Downloading the teacher buffer...")
    url = f"https://drive.google.com/uc?id={file_ids[env_name]}"
    gdown.download(url, file_path, quiet=False)
    # Some gdown versions report failures without raising; check explicitly.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Failed to download the teacher buffer to {file_path}")
    print("Download Complete!")


def load_buffer(env_name, level, datasize):
    """Load a pre-collected teacher buffer, downloading it first if missing.

    The buffer is read from
    ``<this file's directory>/teacher_buffer/[<level>_buffer]_<env_name>.pickle``.
    If that file does not exist it is downloaded from Google Drive; downloads
    exist only for the ``'expert'`` and ``'medium'`` levels of
    ``Ant-v3``, ``Walker2d-v3``, ``HalfCheetah-v3`` and ``Hopper-v3``.

    Args:
        env_name: environment name, e.g. ``"Hopper-v3"``.
        level: buffer level, ``'expert'`` or ``'medium'``.
        datasize: value assigned to the loaded buffer's ``size`` attribute
            (presumably caps how many stored transitions are used -- confirm
            against the buffer class).

    Returns:
        The unpickled buffer object with ``size`` set to ``datasize``.

    Raises:
        ValueError: if the file is missing and no download exists for
            ``level`` / ``env_name``.
        FileNotFoundError: if the download did not produce the file.
    """
    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    file_path = os.path.join(current_dir, "teacher_buffer", "[" + level + "_buffer]_" + env_name + ".pickle")

    try:
        return _read_buffer(file_path, datasize)
    except FileNotFoundError:
        pass  # not cached locally yet -> download below

    _download_buffer(env_name, level, file_path)
    return _read_buffer(file_path, datasize)
|
|
|