# config.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import yaml
  17. def override(dl, ks, v):
  18. """
  19. Recursively replace dict of list
  20. Args:
  21. dl(dict or list): dict or list to be replaced
  22. ks(list): list of keys
  23. v(str): value to be replaced
  24. """
  25. def str2num(v):
  26. try:
  27. return eval(v)
  28. except Exception:
  29. return v
  30. assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
  31. assert len(ks) > 0, ('lenght of keys should larger than 0')
  32. if isinstance(dl, list):
  33. k = str2num(ks[0])
  34. if len(ks) == 1:
  35. assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
  36. dl[k] = str2num(v)
  37. else:
  38. override(dl[k], ks[1:], v)
  39. else:
  40. if len(ks) == 1:
  41. #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  42. if not ks[0] in dl:
  43. logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
  44. dl[ks[0]] = str2num(v)
  45. else:
  46. assert ks[0] in dl, (
  47. '({}) doesn\'t exist in {}, a new dict field is invalid'.
  48. format(ks[0], dl))
  49. override(dl[ks[0]], ks[1:], v)
  50. def override_config(config, options=None):
  51. """
  52. Recursively override the config
  53. Args:
  54. config(dict): dict to be replaced
  55. options(list): list of pairs(key0.key1.idx.key2=value)
  56. such as: [
  57. 'topk=2',
  58. 'VALID.transforms.1.ResizeImage.resize_short=300'
  59. ]
  60. Returns:
  61. config(dict): replaced config
  62. """
  63. if options is not None:
  64. for opt in options:
  65. assert isinstance(opt, str), (
  66. "option({}) should be a str".format(opt))
  67. assert "=" in opt, (
  68. "option({}) should contain a ="
  69. "to distinguish between key and value".format(opt))
  70. pair = opt.split('=')
  71. assert len(pair) == 2, ("there can be only a = in the option")
  72. key, value = pair
  73. keys = key.split('.')
  74. override(config, keys, value)
  75. return config
  76. class ArgsParser(ArgumentParser):
  77. def __init__(self):
  78. super(ArgsParser, self).__init__(
  79. formatter_class=RawDescriptionHelpFormatter)
  80. self.add_argument("-c", "--config", help="configuration file to use")
  81. self.add_argument(
  82. "-t", "--tag", default="0", help="tag for marking worker")
  83. self.add_argument(
  84. '-o',
  85. '--override',
  86. action='append',
  87. default=[],
  88. help='config options to be overridden')
  89. self.add_argument(
  90. "--style_image", default="examples/style_images/1.jpg", help="tag for marking worker")
  91. self.add_argument(
  92. "--text_corpus", default="PaddleOCR", help="tag for marking worker")
  93. self.add_argument(
  94. "--language", default="en", help="tag for marking worker")
  95. def parse_args(self, argv=None):
  96. args = super(ArgsParser, self).parse_args(argv)
  97. assert args.config is not None, \
  98. "Please specify --config=configure_file_path."
  99. return args
  100. def load_config(file_path):
  101. """
  102. Load config from yml/yaml file.
  103. Args:
  104. file_path (str): Path of the config file to be loaded.
  105. Returns: config
  106. """
  107. ext = os.path.splitext(file_path)[1]
  108. assert ext in ['.yml', '.yaml'], "only support yaml files for now"
  109. with open(file_path, 'rb') as f:
  110. config = yaml.load(f, Loader=yaml.Loader)
  111. return config
  112. def gen_config():
  113. base_config = {
  114. "Global": {
  115. "algorithm": "SRNet",
  116. "use_gpu": True,
  117. "start_epoch": 1,
  118. "stage1_epoch_num": 100,
  119. "stage2_epoch_num": 100,
  120. "log_smooth_window": 20,
  121. "print_batch_step": 2,
  122. "save_model_dir": "./output/SRNet",
  123. "use_visualdl": False,
  124. "save_epoch_step": 10,
  125. "vgg_pretrain": "./pretrained/VGG19_pretrained",
  126. "vgg_load_static_pretrain": True
  127. },
  128. "Architecture": {
  129. "model_type": "data_aug",
  130. "algorithm": "SRNet",
  131. "net_g": {
  132. "name": "srnet_net_g",
  133. "encode_dim": 64,
  134. "norm": "batch",
  135. "use_dropout": False,
  136. "init_type": "xavier",
  137. "init_gain": 0.02,
  138. "use_dilation": 1
  139. },
  140. # input_nc, ndf, netD,
  141. # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
  142. "bg_discriminator": {
  143. "name": "srnet_bg_discriminator",
  144. "input_nc": 6,
  145. "ndf": 64,
  146. "netD": "basic",
  147. "norm": "none",
  148. "init_type": "xavier",
  149. },
  150. "fusion_discriminator": {
  151. "name": "srnet_fusion_discriminator",
  152. "input_nc": 6,
  153. "ndf": 64,
  154. "netD": "basic",
  155. "norm": "none",
  156. "init_type": "xavier",
  157. }
  158. },
  159. "Loss": {
  160. "lamb": 10,
  161. "perceptual_lamb": 1,
  162. "muvar_lamb": 50,
  163. "style_lamb": 500
  164. },
  165. "Optimizer": {
  166. "name": "Adam",
  167. "learning_rate": {
  168. "name": "lambda",
  169. "lr": 0.0002,
  170. "lr_decay_iters": 50
  171. },
  172. "beta1": 0.5,
  173. "beta2": 0.999,
  174. },
  175. "Train": {
  176. "batch_size_per_card": 8,
  177. "num_workers_per_card": 4,
  178. "dataset": {
  179. "delimiter": "\t",
  180. "data_dir": "/",
  181. "label_file": "tmp/label.txt",
  182. "transforms": [{
  183. "DecodeImage": {
  184. "to_rgb": True,
  185. "to_np": False,
  186. "channel_first": False
  187. }
  188. }, {
  189. "NormalizeImage": {
  190. "scale": 1. / 255.,
  191. "mean": [0.485, 0.456, 0.406],
  192. "std": [0.229, 0.224, 0.225],
  193. "order": None
  194. }
  195. }, {
  196. "ToCHWImage": None
  197. }]
  198. }
  199. }
  200. }
  201. with open("config.yml", "w") as f:
  202. yaml.dump(base_config, f)
  203. if __name__ == '__main__':
  204. gen_config()