| | try: |
| | import cPickle as pickle |
| | except: |
| | import pickle |
| | import ast |
| | import re |
| | import inspect |
| | import os |
| | import logging |
| | import numpy as np |
| |
|
| | def cross_entropy_npy(a, b): |
| | return a * np.log(b + 1E-9) + (1 - a) * np.log(1 - b + 1E-9) |
| |
|
| |
|
| | def safe_eval(expr): |
| | if type(expr) is str: |
| | return ast.literal_eval(expr) |
| | else: |
| | return expr |
| |
|
| |
|
| | def logging_config(folder=None, name=None, |
| | level=logging.INFO, |
| | console_level=logging.DEBUG): |
| | """ |
| | |
| | Parameters |
| | ---------- |
| | folder : str or None |
| | name : str or None |
| | level : int |
| | console_level |
| | |
| | Returns |
| | ------- |
| | |
| | """ |
| | if name is None: |
| | name = inspect.stack()[1][1].split('.')[0] |
| | if folder is None: |
| | folder = os.path.join(os.getcwd(), name) |
| | if not os.path.exists(folder): |
| | os.makedirs(folder) |
| | |
| | for handler in logging.root.handlers: |
| | logging.root.removeHandler(handler) |
| | logging.root.handlers = [] |
| | logpath = os.path.join(folder, name + ".log") |
| | print("All Logs will be saved to %s" %logpath) |
| | logging.root.setLevel(level) |
| | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| | logfile = logging.FileHandler(logpath) |
| | logfile.setLevel(level) |
| | logfile.setFormatter(formatter) |
| | logging.root.addHandler(logfile) |
| | |
| | logconsole = logging.StreamHandler() |
| | logconsole.setLevel(console_level) |
| | logconsole.setFormatter(formatter) |
| | logging.root.addHandler(logconsole) |
| | return folder |
| |
|
| |
|
| | def load_params(prefix, epoch): |
| | """ |
| | |
| | Parameters |
| | ---------- |
| | prefix : str |
| | epoch : int |
| | |
| | Returns |
| | ------- |
| | arg_params : dict |
| | aux_params : dict |
| | """ |
| | import mxnet.ndarray as nd |
| | save_dict = nd.load('%s-%04d.params' % (prefix, epoch)) |
| | arg_params = {} |
| | aux_params = {} |
| | for k, v in save_dict.items(): |
| | tp, name = k.split(':', 1) |
| | if tp == 'arg': |
| | arg_params[name] = v |
| | if tp == 'aux': |
| | aux_params[name] = v |
| | return arg_params, aux_params |
| |
|
| |
|
| | def parse_ctx(ctx_args): |
| | import mxnet as mx |
| | ctx = re.findall('([a-z]+)(\d*)', ctx_args) |
| | ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx] |
| | ctx = [mx.Context(*ele) for ele in ctx] |
| | return ctx |
| |
|