# det_r18_vd_ct.yml
  # Run-wide settings: device, training length, logging cadence, and the
  # paths used for checkpoints, pretrained weights and prediction output.
  Global:
    use_gpu: true
    epoch_num: 600
    log_smooth_window: 20
    print_batch_step: 10
    save_model_dir: ./output/det_ct/
    save_epoch_step: 10
    # evaluation is run every 1000 iterations, starting from iteration 0
    # (assumes the [start_iteration, interval] convention - confirm against
    # the training script)
    eval_batch_step: [0, 1000]
    cal_metric_during_train: false
    pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
    # The next two keys are deliberately left without a value (parsed as null).
    checkpoints:
    save_inference_dir:
    use_visualdl: false
    infer_img: doc/imgs_en/img623.jpg
    save_res_path: ./output/det_ct/predicts_ct.txt
  # Model definition: a detection model using the CT algorithm. Each
  # component sub-section names its implementation and that component's
  # own arguments.
  Architecture:
    model_type: det
    algorithm: CT
    # Left without a value (parsed as null).
    Transform:
    Backbone:
      name: ResNet_vd
      layers: 18
    Neck:
      name: CTFPN
    Head:
      name: CT_Head
      in_channels: 512
      hidden_dim: 128
      num_classes: 3
  # Loss selected by name; no further arguments are set here.
  Loss:
    name: CTLoss
  # Optimizer and its learning-rate schedule. The schedule settings live in
  # the nested `lr` mapping so its `name` does not clash with the optimizer's.
  Optimizer:
    name: Adam
    lr:  # PolynomialDecay
      name: Linear
      learning_rate: 0.001
      end_lr: 0.0
      # Same value as Global.epoch_num; keep the two in sync.
      epochs: 600
      # presumably the number of training iterations per epoch for this
      # dataset and batch size - confirm before changing either
      step_each_epoch: 1254
      power: 0.9
  # Post-processing step selected by name, emitting polygon-type boxes.
  PostProcess:
    name: CTPostProcess
    box_type: poly
  # Evaluation metric; `main_indicator` names the score treated as primary.
  Metric:
    name: CTMetric
    main_indicator: f_score
  # Training data: dataset location, the ordered list of transforms applied
  # to each sample, and the data-loader settings.
  Train:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/total_text/train
      label_file_list:
        - ./train_data/total_text/train/train.txt
      ratio_list: [1.0]
      # Transforms are listed in order. An entry with no value (null) sets
      # no arguments for that operator.
      transforms:
        - DecodeImage:
            img_mode: RGB
            channel_first: false
        - CTLabelEncode:  # Class handling label
        - RandomScale:
        - MakeShrink:
        - GroupRandomHorizontalFlip:
        - GroupRandomRotate:
        - GroupRandomCropPadding:
        - MakeCentripetalShift:
        - ColorJitter:
            brightness: 0.125
            saturation: 0.5
        # NOTE(review): here NormalizeImage comes after ToCHWImage and sets
        # no `order`, whereas Eval normalises before ToCHWImage with
        # order: 'hwc' - confirm this ordering is intended.
        - ToCHWImage:
        - NormalizeImage:
        - KeepKeys:
            # the order of the dataloader list
            keep_keys:
              - image
              - gt_kernel
              - training_mask
              - gt_instance
              - gt_kernel_instance
              - training_mask_distance
              - gt_distance
    loader:
      shuffle: true
      drop_last: true
      batch_size_per_card: 4
      num_workers: 8
  # Evaluation data: dataset location, the ordered list of transforms applied
  # to each sample, and the data-loader settings (no shuffling, one image
  # per card).
  Eval:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/total_text/test
      label_file_list:
        - ./train_data/total_text/test/test.txt
      ratio_list: [1.0]
      # Transforms are listed in order. An entry with no value (null) sets
      # no arguments for that operator.
      transforms:
        - DecodeImage:
            img_mode: RGB
            channel_first: false
        - CTLabelEncode:  # Class handling label
        - ScaleAlignedShort:
        # Normalisation is declared with order 'hwc' and is listed before
        # the ToCHWImage step.
        - NormalizeImage:
            order: 'hwc'
        - ToCHWImage:
        - KeepKeys:
            # the order of the dataloader list
            keep_keys: ['image', 'shape', 'polys', 'texts']
    loader:
      shuffle: false
      drop_last: false
      batch_size_per_card: 1
      num_workers: 2