style_text_rec.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger
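

# StyleTextRec is the top-level predictor: it builds a text generator, a
# background generator, and a fusion generator from the "Predictor" section
# of the config, and loads dygraph pretrained weights for each of them.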
class StyleTextRec(nn.Layer):
    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        self.text_generator = TextGenerator(config["Predictor"][
            "text_generator"])
        self.bg_generator = BgGeneratorWithMask(config["Predictor"][
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
            "fusion_generator"])
        bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
        text_generator_pretrain = config["Predictor"]["text_generator"][
            "pretrain"]
        fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
            "pretrain"]
        load_dygraph_pretrain(
            self.bg_generator,
            self.logger,
            path=bg_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.text_generator,
            self.logger,
            path=text_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.fusion_generator,
            self.logger,
            path=fusion_generator_pretrain,
            load_static_weights=False)

    def forward(self, style_input, text_input):
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]
        bg_gen_output = self.bg_generator.forward(style_input)
        bg_encode_feature = bg_gen_output["bg_encode_feature"]
        bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
        bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
        fake_bg = bg_gen_output["fake_bg"]
        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        fake_fusion = fusion_gen_output["fake_fusion"]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }
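

# TextGenerator encodes the style image and the standard text image with two
# separate encoders, decodes a foreground text image (Tanh output) and a
# 1-channel text skeleton map (Sigmoid output) from the paired text/style
# features, and merges them through MiddleNet into the 3-channel "fake_text".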
class TextGenerator(nn.Layer):
    def __init__(self, config):
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=int(encode_dim / 2),
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=int(encode_dim / 2) + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward(
            [text_feature, style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward(
            [text_feature, style_feature])["out_conv"]
        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
        return {"fake_sk": fake_sk, "fake_text": fake_text}
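

# BgGeneratorWithMask reconstructs the background of the style image: a
# SingleDecoder fed with the encoder's "down1"/"down2" skip features predicts
# the background image (Tanh output), a second decoder predicts a 1-channel
# background mask (Sigmoid output), and MiddleNet combines them into "fake_bg".
# Intermediate encoder/decoder features are also returned for downstream use.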
class BgGeneratorWithMask(nn.Layer):
    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        self.output_factor = config.get("output_factor", 1.0)
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        encode_bg_output = self.encoder_bg(style_input)
        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
                                           encode_bg_output["down2"],
                                           encode_bg_output["down1"])
        fake_c_temp = decode_bg_output["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
            "res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat(
                (fake_c_temp, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encode_bg_output["res_blocks"],
            "bg_decode_feature1": decode_bg_output["up1"],
            "bg_decode_feature2": decode_bg_output["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }
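

# FusionGeneratorSimple blends the generated text and background: it
# channel-concatenates fake_text and fake_bg (3 + 3 = 6 channels), then applies
# a 3x3 conv, one ResBlock, and a reducing 3x3 conv to produce the 3-channel
# fused output "fake_fusion".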
class FusionGeneratorSimple(nn.Layer):
    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)
        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)
        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
        fake_concat_tmp = self._conv(fake_concat)
        output_res = self._res_block(fake_concat_tmp)
        fake_fusion = self._reduce_conv(output_res)
        return {"fake_fusion": fake_fusion}
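

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# config layout mirrors the keys read above; the concrete values and pretrain
# paths are placeholders and must match the weights shipped with the project.
#
#   config = {
#       "Predictor": {
#           "text_generator": {
#               "module_name": "text_gen",
#               "encode_dim": 64,
#               "norm_layer": "InstanceNorm2D",
#               "conv_block_dropout": False,
#               "conv_block_num": 4,
#               "conv_block_dilation": True,
#               "pretrain": "path/to/text_generator",
#           },
#           "bg_generator": {...},      # same keys, plus optional "output_factor"
#           "fusion_generator": {...},  # same keys, without "conv_block_num"
#       }
#   }
#   predictor = StyleTextRec(config)
#   # style_input / text_input: NCHW float tensors with 3 channels each
#   outputs = predictor(style_input, text_input)
#   fused_image = outputs["fake_fusion"]
# ---------------------------------------------------------------------------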