在完成源码编写后,有些类或者函数需要通过命令行传递参数,可以使用fire工具包:
下面的代码中,如果仅需要给add传递参数,将add暴露给命令行的方法只需要在文件末尾添加:fire.Fire()
对类中的parse函数,需要暴露给命令的方法为在文件末尾添加:fire.Fire(DefaultConfig)
一般不同时暴露上面的两个。
import fire
class DefaultConfig(object):
def __init__(self):
self.env = 'default' # visdom environment
self.model = 'AlexNet' # model name must be the same with models in 'models/__init__.py'
self.train_data_root = '/home/sjtuer/Desktop/DRL/S6/data/train'
self.test_data_root = '/home/sjtuer/Desktop/DRL/S6/data/test'
self.load_model_path = None # 'checkpoint/model.py' # load the pre-trained model, if None: no load
self.batch_size = 128
self.use_gpu = True
self.num_workers = 4 # how many workers for loading data
self.print_freq = 20 # print info every N batch
self.debug_file = '/tmp/debug' # if os.path.exists(debug_file): enter ipdb
self.result_file = 'result.csv'
self.max_epoch = 10
self.lr = 0.001
self.lr_decay = 0.95 # when val_loss increase, lr = lr * lr_decay
weight_decay = 1e-4 # loss function
def parse(self, **kwargs):
'''
update configuration based on kwargs dictionary
'''
for k,v in kwargs.items():
if not hasattr(self,k):
warning.warn("Warnnin: opt has not attribut %s" %k)
setattr(self, k, v)
# print(self.__dict__)
# print(self.__class__)
print('user config:')
for k, v in self.__dict__.items(): # __class__ 表示当前操作的对象的类
# print(k)
if not k.startswith('__'):
print(k, getattr(self, k))
opt = DefaultConfig()
def add(a,b):
return a+b
if __name__ == '__main__':
dc = DefaultConfig()
print(dc.batch_size)
dc.parse()
print(dc.batch_size)
fire.Fire()
转载:https://blog.csdn.net/weixin_37532614/article/details/104616320
查看评论