123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is adapted from:
- https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- from paddle import nn
- import paddle.nn.functional as F
class PoolAggregate(nn.Layer):
    """Aggregate a feature map with ``n_r`` parallel conv-and-pool branches.

    Each branch is ``conv1 -> bn1 -> act -> conv2 -> bn2``; both convolutions
    are 3x3 with stride 2, padding 1 and no bias. The branch output is
    globally average-pooled to one ``d_out``-sized vector, and the vectors of
    all branches are stacked along axis 1.

    Args:
        n_r: Number of parallel branches.
        d_in: Number of input channels.
        d_middle: Channels between the two convolutions; ``d_in`` is used
            when this is falsy (``None`` or ``0``).
        d_out: Output channels of each branch; ``d_in`` is used when this is
            falsy (``None`` or ``0``).
    """

    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(PoolAggregate, self).__init__()
        self.d_in = d_in
        self.d_middle = d_middle if d_middle else d_in
        self.d_out = d_out if d_out else d_in
        # One Swish instance is shared by every branch.
        self.act = nn.Swish()
        self.n_r = n_r
        self.aggs = self._build_aggs()

    def _make_branch(self):
        """Create one ``conv1 -> bn1 -> act -> conv2 -> bn2`` branch."""
        conv1 = nn.Conv2D(self.d_in, self.d_middle, 3, 2, 1, bias_attr=False)
        bn1 = nn.BatchNorm(self.d_middle)
        conv2 = nn.Conv2D(self.d_middle, self.d_out, 3, 2, 1, bias_attr=False)
        bn2 = nn.BatchNorm(self.d_out)
        return nn.Sequential(
            ('conv1', conv1),
            ('bn1', bn1),
            ('act', self.act),
            ('conv2', conv2),
            ('bn2', bn2),
        )

    def _build_aggs(self):
        """Create the ``n_r`` branches and register them as sublayers.

        Branches are registered under the names ``"0"``, ``"1"``, ... and
        also returned as a plain list so ``forward`` can iterate over them.
        """
        branches = []
        for idx in range(self.n_r):
            branch = self.add_sublayer(str(idx), self._make_branch())
            branches.append(branch)
        return branches

    def forward(self, x):
        """Return the stacked branch descriptors, shape ``[b, n_r, d_out]``.

        ``b`` is ``x.shape[0]``; ``x`` is fed unchanged to every branch.
        """
        batch = x.shape[0]
        pooled = [
            F.adaptive_avg_pool2d(branch(x), 1).reshape(
                (batch, 1, self.d_out))
            for branch in self.aggs
        ]
        return paddle.concat(pooled, 1)
class WeightAggregate(nn.Layer):
    """Aggregate a feature map with ``n_r`` learned spatial weight maps.

    ``conv_n`` turns the input into ``n_r`` sigmoid-activated maps and
    ``conv_d`` projects the input to ``d_out`` channels. ``forward``
    multiplies the two over the flattened spatial positions.

    Args:
        n_r: Number of weight maps (rows of the result).
        d_in: Number of input channels.
        d_middle: Hidden channels of ``conv_d``; ``d_in`` is used when this
            is falsy (``None`` or ``0``).
        d_out: Output channels of ``conv_d``; ``d_in`` is used when this is
            falsy (``None`` or ``0``).
    """

    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(WeightAggregate, self).__init__()
        d_middle = d_middle or d_in
        d_out = d_out or d_in

        self.n_r = n_r
        self.d_out = d_out
        # One Swish instance is shared by both sub-networks.
        self.act = nn.Swish()

        # Weight-map head: 3x3 conv -> BN -> Swish -> 1x1 conv -> BN -> sigmoid.
        weight_layers = [
            ('conv1', nn.Conv2D(d_in, d_in, 3, 1, 1, bias_attr=False)),
            ('bn1', nn.BatchNorm(d_in)),
            ('act1', self.act),
            ('conv2', nn.Conv2D(d_in, n_r, 1, bias_attr=False)),
            ('bn2', nn.BatchNorm(n_r)),
            ('act2', nn.Sigmoid()),
        ]
        self.conv_n = nn.Sequential(*weight_layers)

        # Feature head: 3x3 conv -> BN -> Swish -> 1x1 conv -> BN.
        feature_layers = [
            ('conv1', nn.Conv2D(d_in, d_middle, 3, 1, 1, bias_attr=False)),
            ('bn1', nn.BatchNorm(d_middle)),
            ('act1', self.act),
            ('conv2', nn.Conv2D(d_middle, d_out, 1, bias_attr=False)),
            ('bn2', nn.BatchNorm(d_out)),
        ]
        self.conv_d = nn.Sequential(*feature_layers)

    def forward(self, x):
        """Return the weighted aggregation, shape ``[b, n_r, d_out]``.

        ``x`` must be 4-D; its last two axes are flattened into ``h * w``
        positions before the batched matrix product.
        """
        batch, _, height, width = x.shape
        positions = height * width

        hmaps = self.conv_n(x)
        fmaps = self.conv_d(x)

        weights = hmaps.reshape((batch, self.n_r, positions))
        # Move positions to the middle axis so bmm contracts over them.
        feats = fmaps.reshape((batch, self.d_out, positions)).transpose(
            (0, 2, 1))
        return paddle.bmm(weights, feats)
class GCN(nn.Layer):
    """Node-mixing layer: 1x1 ``Conv1D`` over nodes, then a linear projection.

    ``conv_n`` treats axis 1 of the input as channels, mapping ``n_in`` nodes
    to ``n_out`` nodes. ``linear`` then maps the last axis from ``d_in`` to
    ``d_out``, followed by dropout and Swish.

    Args:
        d_in: Feature size of each input node (last axis).
        n_in: Number of input nodes (axis 1).
        d_out: Feature size of each output node; ``d_in`` is used when this
            is falsy (``None`` or ``0``).
        n_out: Number of output nodes; ``n_in`` is used when this is falsy
            (``None`` or ``0``).
        dropout: Dropout probability applied after the linear projection.
    """

    def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1):
        super(GCN, self).__init__()
        if not d_out:
            d_out = d_in
        if not n_out:
            # n_out is the output channel count of conv_n, which acts on the
            # node axis, so the "unchanged" default is n_in. The previous
            # default of d_in silently resized the node axis to the feature
            # width.
            n_out = n_in

        self.conv_n = nn.Conv1D(n_in, n_out, 1)
        self.linear = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.Swish()

    def forward(self, x):
        """Map ``[b, n_in, d_in]`` to ``[b, n_out, d_out]``."""
        x = self.conv_n(x)
        x = self.dropout(self.linear(x))
        return self.act(x)
class PRENFPN(nn.Layer):
    """Neck that fuses three feature maps through pooling and weight branches.

    Each of the three inputs goes through a ``PoolAggregate`` and a
    ``WeightAggregate`` that both emit ``d_model // 3`` features. The three
    results of each kind are concatenated on the last axis, passed through
    their own ``GCN``, and the two GCN outputs are averaged.

    Args:
        in_channels: Channel counts of the three input feature maps; must
            have length 3.
        n_r: Number of aggregation branches / weight maps per input.
        d_model: Output feature size; must be divisible by 3.
        max_len: Number of output nodes produced by each ``GCN``.
        dropout: Dropout probability used inside each ``GCN``.
    """

    def __init__(self, in_channels, n_r, d_model, max_len, dropout):
        super(PRENFPN, self).__init__()
        assert len(in_channels) == 3, "in_channels' length must be 3."
        c1, c2, c3 = in_channels  # the depths are from big to small

        # build fpn
        assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model)
        d_part = d_model // 3  # each input contributes a third of d_model

        self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_part)
        self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_part)
        self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_part)

        self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_part)
        self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_part)
        self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_part)

        self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout)
        self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout)

        self.out_channels = d_model

    def forward(self, inputs):
        """Fuse the three feature maps in ``inputs`` into ``[b, max_len, d]``."""
        f3, f5, f7 = inputs
        feats = (f3, f5, f7)

        pool_aggs = (self.agg_p1, self.agg_p2, self.agg_p3)
        pooled = paddle.concat(
            [agg(feat) for agg, feat in zip(pool_aggs, feats)],
            2)  # [b,nr,d]

        weight_aggs = (self.agg_w1, self.agg_w2, self.agg_w3)
        weighted = paddle.concat(
            [agg(feat) for agg, feat in zip(weight_aggs, feats)],
            2)  # [b,nr,d]

        y_pool = self.gcn_pool(pooled)
        y_weight = self.gcn_weight(weighted)
        return 0.5 * (y_pool + y_weight)  # [b,max_len,d]
|