123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import yaml
- import os
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
def override(dl, ks, v):
    """
    Recursively replace a value inside a nested dict or list, in place.

    Walks ``dl`` along the path ``ks`` and assigns ``v`` (converted by
    ``str2num``) at the last key/index. At the final level of a dict a new
    key may be created (a warning is logged); intermediate dict keys must
    already exist.

    Args:
        dl(dict or list): dict or list to be modified in place
        ks(list): list of keys; list indices arrive as strings such as
            ``'1'`` and are converted by ``str2num``
        v(str): value to be written; evaluated with ``eval`` when possible,
            otherwise stored as the raw string

    Raises:
        AssertionError: if ``dl`` is not a dict/list, ``ks`` is empty, a
            list index is out of range, or an intermediate dict key is
            missing.
    """
    import logging

    def str2num(v):
        # NOTE(review): eval() executes arbitrary expressions taken from the
        # command line. It is tolerable only because options come from the
        # local user; consider ast.literal_eval if they can be untrusted.
        try:
            return eval(v)
        except Exception:
            # Not a valid Python expression: keep the raw string.
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), (
                'index({}) out of range({})'.format(k, len(dl)))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            if ks[0] not in dl:
                # Creating a new leaf key is allowed, but surface it because
                # it often indicates a typo in the option name. The original
                # referenced an undefined `logger` here and crashed.
                logging.getLogger(__name__).warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Recursively override the config with ``key.path=value`` options.

    Args:
        config(dict): dict to be modified in place
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
            Only the first '=' separates key from value, so a value may
            itself contain '='.

    Returns:
        config(dict): the same (now modified) config object

    Raises:
        AssertionError: if an option is not a str or has no '='.
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a = "
                "to distinguish between key and value".format(opt))
            # Split on the first '=' only; everything after it is the value.
            key, value = opt.split('=', 1)
            keys = key.split('.')
            override(config, keys, value)
    return config
class ArgsParser(ArgumentParser):
    """
    Command-line parser with config, tag, override and text-rendering options.

    ``-o/--override`` may be repeated; each value is collected into a list
    (e.g. ``-o Global.use_gpu=False``).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')
        # The three options below previously reused the "tag for marking
        # worker" help text by copy-paste; give each its own description.
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style image")
        self.add_argument(
            "--text_corpus",
            default="PaddleOCR",
            help="text corpus to use")
        self.add_argument(
            "--language", default="en", help="language of the text corpus")

    def parse_args(self, argv=None):
        """Parse ``argv`` (or ``sys.argv``) and require ``--config``.

        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args
def load_config(file_path):
    """
    Parse a YAML configuration file and return its contents.

    Args:
        file_path (str): Location of a file ending in ``.yml`` or ``.yaml``.

    Returns:
        The object ``yaml.load`` produces for the file (a dict when the
        document is a mapping).

    Raises:
        AssertionError: when the file extension is neither ``.yml`` nor
            ``.yaml``.
    """
    _, suffix = os.path.splitext(file_path)
    assert suffix in ('.yml', '.yaml'), "only support yaml files for now"
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tagged nodes; only load trusted files (yaml.SafeLoader would be safer
    # if full tag support is not needed).
    with open(file_path, 'rb') as stream:
        parsed = yaml.load(stream, Loader=yaml.Loader)
    return parsed
def gen_config(save_path="config.yml"):
    """
    Write the default SRNet training configuration to a YAML file.

    Args:
        save_path (str): Destination file. Defaults to ``"config.yml"``
            (relative to the current working directory), which was
            previously hard-coded.

    Returns:
        dict: the configuration that was written, so callers can inspect or
        reuse it without re-reading the file.
    """
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # NOTE(review): presumably mirrors the discriminator constructor
            # signature (defined elsewhere — confirm):
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False,
            # init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    # Explicit encoding so the output does not depend on the platform locale.
    with open(save_path, "w", encoding="utf-8") as f:
        yaml.dump(base_config, f)
    return base_config
if __name__ == "__main__":
    # Running this module as a script emits the default config file.
    gen_config()
|