"""
This module contains the encoder and decoder modules which
are components of the Orca models.
"""
import numpy as np
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
# Length (in input positions along the last axis) of each segment that
# ``Encoder.forward`` runs through gradient checkpointing.
# It can be set to a lower value to decrease memory usage;
# at least 4000 * 50 is recommended for performance.
Blocksize = 4000 * 200
class Decoder(nn.Module):
    def __init__(self):
        """
        Orca decoder architecture.

        The decoder is a stack of 28 stages of 3x3 2D convolutions. Each
        stage pairs a linear block (``lconvtwos``) with a ReLU block
        (``convtwos``); the dilation cycles through 1, 2, 4, 8, 16, 32, 64
        four times. Separate combiner blocks merge the pairwise encoding
        with the distance encoding (``lcombinerD`` / ``combinerD``) and
        with an optional upsampled input (``lcombiner`` / ``combiner``).
        """
        super().__init__()
        # Dilation (and matching padding) for every stage, in order.
        schedule = [2 ** power for power in range(7)] * 4

        # Only the very first linear block is preceded by dropout.
        self.lconvtwos = nn.ModuleList(
            [
                self._linear_pair(64, 32, 64, rate, dropout=(position == 0))
                for position, rate in enumerate(schedule)
            ]
        )
        self.convtwos = nn.ModuleList(
            [self._relu_pair(64, 32, 64, rate) for rate in schedule]
        )

        # 1x1 convolutions reducing the 64 feature maps to a single map.
        self.final = nn.Sequential(
            nn.Conv2d(64, 5, kernel_size=1),
            nn.BatchNorm2d(5),
            nn.ReLU(inplace=True),
            nn.Conv2d(5, 1, kernel_size=1),
        )
        self.upsample = nn.Upsample(scale_factor=(2, 2))

        # 65 input channels: 64 feature maps plus the upsampled ``y``.
        self.lcombiner = self._linear_pair(65, 64, 64, dropout=True)
        self.combiner = self._relu_pair(64, 64, 64)

        # 129 input channels: pairwise encoding concatenated with ``distenc``.
        self.lcombinerD = self._linear_pair(129, 64, 64)
        self.combinerD = self._relu_pair(64, 64, 64)

    @staticmethod
    def _linear_pair(n_in, n_mid, n_out, rate=1, dropout=False):
        """Two 3x3 conv + batch-norm layers without nonlinearity."""
        layers = [nn.Dropout(p=0.1)] if dropout else []
        layers.extend(
            [
                nn.Conv2d(n_in, n_mid, kernel_size=3, padding=rate, dilation=rate),
                nn.BatchNorm2d(n_mid),
                nn.Conv2d(n_mid, n_out, kernel_size=3, padding=rate, dilation=rate),
                nn.BatchNorm2d(n_out),
            ]
        )
        return nn.Sequential(*layers)

    @staticmethod
    def _relu_pair(n_in, n_mid, n_out, rate=1):
        """Two 3x3 conv + batch-norm + in-place ReLU layers."""
        return nn.Sequential(
            nn.Conv2d(n_in, n_mid, kernel_size=3, padding=rate, dilation=rate),
            nn.BatchNorm2d(n_mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_mid, n_out, kernel_size=3, padding=rate, dilation=rate),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, distenc, y=None):
        """Forward propagation of a batch.

        Parameters
        ----------
        x : torch.Tensor
            Sequence encoding, indexed here as a 3D tensor. Its channel
            count plus that of `distenc` must equal 129 (presumably
            128 + 1 -- confirm against the caller).
        distenc : torch.Tensor
            Distance encoding, concatenated with the pairwise encoding
            along dimension 1.
        y : torch.Tensor or None, optional
            If given, it is upsampled by a factor of 2 and concatenated
            as an extra channel; the first stage then uses
            ``lcombiner`` / ``combiner`` instead of the first entries of
            ``lconvtwos`` / ``convtwos``.

        Returns
        -------
        torch.Tensor
            Output of ``final`` averaged with its transpose over
            dimensions 2 and 3, hence symmetric in those dimensions.
        """
        # Outer sum: position (i, j) holds x[..., i] + x[..., j].
        pairwise = x.unsqueeze(3) + x.unsqueeze(2)
        pairwise = self.lcombinerD(torch.cat([pairwise, distenc], axis=1))
        pairwise = self.combinerD(pairwise) + pairwise

        stages = zip(self.lconvtwos, self.convtwos)
        head_linear, head_relu = next(stages)
        if y is not None:
            pairwise = torch.cat([pairwise, self.upsample(y)], axis=1)
            # The extra channel requires the 65-channel entry blocks.
            head_linear, head_relu = self.lcombiner, self.combiner

        # First stage: no residual connection around the linear block.
        cur = head_linear(pairwise)
        cur = head_relu(cur) + cur

        for linear, relu in stages:
            lout = linear(cur)
            # Residual connection only when the shapes are compatible.
            cur = lout + cur if lout.size() == cur.size() else lout
            cur = relu(cur) + cur

        cur = self.final(cur)
        return 0.5 * cur + 0.5 * cur.transpose(2, 3)
class Decoder_1m(nn.Module):
    def __init__(self):
        """
        Decoder for training the 1Mb module. It is used either for
        pretraining the Encoder, or together with the Encoder as a
        standalone 1Mb model.

        The decoder has 19 stages of 3x3 2D convolutions: one
        undilated stage followed by three cycles of dilations
        2, 4, 8, 16, 32, 64. Each stage pairs a linear block
        (``lconvtwos``) with a ReLU block (``convtwos``).
        """
        super().__init__()
        # Dilation (and matching padding) for every stage, in order.
        schedule = [1] + [2 ** power for power in range(1, 7)] * 3

        # The first linear block takes 128 channels and has dropout;
        # all later ones map 64 -> 32 -> 64 channels.
        self.lconvtwos = nn.ModuleList(
            [
                self._linear_pair(128 if position == 0 else 64, rate, dropout=(position == 0))
                for position, rate in enumerate(schedule)
            ]
        )
        self.convtwos = nn.ModuleList([self._relu_pair(rate) for rate in schedule])

        # 1x1 convolutions reducing the 64 feature maps to a single map.
        self.final = nn.Sequential(
            nn.Conv2d(64, 5, kernel_size=1),
            nn.BatchNorm2d(5),
            nn.ReLU(inplace=True),
            nn.Conv2d(5, 1, kernel_size=1),
        )

    @staticmethod
    def _linear_pair(n_in, rate, dropout=False):
        """Two 3x3 conv + batch-norm layers (n_in -> 32 -> 64), no nonlinearity."""
        layers = [nn.Dropout(p=0.1)] if dropout else []
        layers.extend(
            [
                nn.Conv2d(n_in, 32, kernel_size=3, padding=rate, dilation=rate),
                nn.BatchNorm2d(32),
                nn.Conv2d(32, 64, kernel_size=3, padding=rate, dilation=rate),
                nn.BatchNorm2d(64),
            ]
        )
        return nn.Sequential(*layers)

    @staticmethod
    def _relu_pair(rate):
        """Two 3x3 conv + batch-norm + in-place ReLU layers (64 -> 32 -> 64)."""
        return nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=rate, dilation=rate),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=rate, dilation=rate),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward propagation of a batch.

        Parameters
        ----------
        x : torch.Tensor
            Sequence encoding, indexed here as a 3D tensor; the first
            convolution expects 128 channels in dimension 1.

        Returns
        -------
        torch.Tensor
            Output of ``final`` averaged with its transpose over
            dimensions 2 and 3, hence symmetric in those dimensions.
        """
        # Outer sum: position (i, j) holds x[..., i] + x[..., j].
        pairwise = x.unsqueeze(3) + x.unsqueeze(2)

        stages = zip(self.lconvtwos, self.convtwos)

        # First stage: no residual connection around the linear block.
        linear, relu = next(stages)
        cur = linear(pairwise)
        cur = relu(cur) + cur

        for linear, relu in stages:
            lout = linear(cur)
            # Residual connection only when the shapes are compatible.
            cur = lout + cur if lout.size() == cur.size() else lout
            cur = relu(cur) + cur

        cur = self.final(cur)
        return 0.5 * cur + 0.5 * cur.transpose(2, 3)
class Encoder(nn.Module):
    def __init__(self):
        """
        The first section of the Orca Encoder (sequence at bp resolution
        to 4kb resolution).

        Seven stages of 1D convolutions (kernel size 9). Stages 2-7
        start with max pooling of sizes 4, 4, 5, 5, 5 and 2, i.e. an
        overall 4000-fold reduction in length.
        """
        super().__init__()
        self.lconv1 = self._linear_pair(4, 64)
        self.conv1 = self._relu_pair(64)

        self.lconv2 = self._linear_pair(64, 96, pool=4)
        self.conv2 = self._relu_pair(96)

        self.lconv3 = self._linear_pair(96, 128, pool=4)
        self.conv3 = self._relu_pair(128)

        self.lconv4 = self._linear_pair(128, 128, pool=5)
        self.conv4 = self._relu_pair(128)

        self.lconv5 = self._linear_pair(128, 128, pool=5)
        self.conv5 = self._relu_pair(128)

        self.lconv6 = self._linear_pair(128, 128, pool=5)
        self.conv6 = self._relu_pair(128)

        self.lconv7 = self._linear_pair(128, 128, pool=2)
        self.conv7 = self._relu_pair(128)

    @staticmethod
    def _linear_pair(n_in, n_out, pool=None):
        """Optional max pooling, then two conv + batch-norm layers."""
        layers = [] if pool is None else [nn.MaxPool1d(kernel_size=pool, stride=pool)]
        layers.extend(
            [
                nn.Conv1d(n_in, n_out, kernel_size=9, padding=4),
                nn.BatchNorm1d(n_out),
                nn.Conv1d(n_out, n_out, kernel_size=9, padding=4),
                nn.BatchNorm1d(n_out),
            ]
        )
        return nn.Sequential(*layers)

    @staticmethod
    def _relu_pair(width):
        """Two conv + batch-norm + in-place ReLU layers of constant width."""
        return nn.Sequential(
            nn.Conv1d(width, width, kernel_size=9, padding=4),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
            nn.Conv1d(width, width, kernel_size=9, padding=4),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward propagation of a batch.

        The input is split along dimension 2 into segments of
        ``Blocksize`` positions. Each segment is extended by 112000
        positions of flanking input on every side that has a
        neighbouring segment, run through the network under gradient
        checkpointing, and the output bins belonging to the flanks are
        discarded before the segments are concatenated.

        Parameters
        ----------
        x : torch.Tensor
            Input indexed as a 3D tensor with 4 channels in dimension 1.

        Returns
        -------
        torch.Tensor
            Concatenation of the per-segment outputs along dimension 2.
        """
        binsize = 4000
        flank = 112000
        chunk = Blocksize
        flank_bins = int(flank / binsize)

        def run(segment, dummy):
            # From stage 2 on, each stage receives the sum of the previous
            # stage's ReLU-block output and linear-block output.
            lout = self.lconv1(segment)
            out = self.conv1(lout)
            for linear, relu in (
                (self.lconv2, self.conv2),
                (self.lconv3, self.conv3),
                (self.lconv4, self.conv4),
                (self.lconv5, self.conv5),
                (self.lconv6, self.conv6),
                (self.lconv7, self.conv7),
            ):
                lout = linear(out + lout)
                out = relu(lout)
            return out

        # Extra checkpoint argument that requires grad; presumably present so
        # that checkpointing records a graph even when ``x`` itself does not
        # require grad -- TODO confirm.
        dummy = torch.Tensor(1)
        dummy.requires_grad = True

        starts = np.arange(0, x.size(2), chunk)
        final_index = len(starts) - 1
        segouts = []
        for index, start in enumerate(starts):
            if index == 0:
                # First segment: flank on the right only.
                window = slice(start, start + chunk + flank)
                keep = slice(None, int(chunk / binsize))
            elif index == final_index:
                # Last segment: flank on the left only, runs to the end.
                window = slice(start - flank, None)
                keep = slice(flank_bins, None)
            else:
                # Interior segment: flanks on both sides.
                window = slice(start - flank, start + chunk + flank)
                keep = slice(flank_bins, int((chunk + flank) / binsize))
            segouts.append(checkpoint(run, x[:, :, window], dummy)[:, :, keep])

        return torch.cat(segouts, 2)
class Encoder2(nn.Module):
    def __init__(self):
        """
        The second section of the Orca Encoder (4kb resolution to
        128kb resolution).

        Five identical stages; each halves the length by max pooling
        and applies a linear block followed by a ReLU block, all with
        128 channels and kernel size 9.
        """
        super().__init__()
        n_stages = 5
        self.lblocks = nn.ModuleList([self._linear_pair() for _ in range(n_stages)])
        self.blocks = nn.ModuleList([self._relu_pair() for _ in range(n_stages)])

    @staticmethod
    def _linear_pair():
        """2x max pooling, then two conv + batch-norm layers."""
        return nn.Sequential(
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
        )

    @staticmethod
    def _relu_pair():
        """Two conv + batch-norm + in-place ReLU layers."""
        return nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward propagation of a batch.

        Returns
        -------
        list of torch.Tensor
            The input `x` followed by the output of each of the five
            stages, in order (six tensors in total).
        """
        encodings = [x]
        for linear, relu in zip(self.lblocks, self.blocks):
            lout = linear(encodings[-1])
            # Residual connection around the ReLU block.
            encodings.append(relu(lout) + lout)
        return encodings
class Encoder3(nn.Module):
    def __init__(self):
        """
        The third section of the Orca Encoder (128kb resolution to
        1024kb resolution).

        Three identical stages; each halves the length by max pooling
        and applies a linear block followed by a ReLU block, all with
        128 channels and kernel size 9.
        """
        super().__init__()
        n_stages = 3
        self.lblocks = nn.ModuleList([self._linear_pair() for _ in range(n_stages)])
        self.blocks = nn.ModuleList([self._relu_pair() for _ in range(n_stages)])

    @staticmethod
    def _linear_pair():
        """2x max pooling, then two conv + batch-norm layers."""
        return nn.Sequential(
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
        )

    @staticmethod
    def _relu_pair():
        """Two conv + batch-norm + in-place ReLU layers."""
        return nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 128, kernel_size=9, padding=4),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward propagation of a batch.

        Returns
        -------
        list of torch.Tensor
            The input `x` followed by the output of each of the three
            stages, in order (four tensors in total).
        """
        encodings = [x]
        for linear, relu in zip(self.lblocks, self.blocks):
            lout = linear(encodings[-1])
            # Residual connection around the ReLU block.
            encodings.append(relu(lout) + lout)
        return encodings
class Net(nn.Module):
    """Orca 1Mb model.

    The network has three parts:

    * a 1D convolutional encoder (``lconv1``/``conv1`` ... ``lconv7``/
      ``conv7``) whose max-pools shrink the sequence length by
      4 * 4 * 5 * 5 * 5 * 2 = 4000;
    * a 2D tower of 19 residual stages (``lconvtwos``/``convtwos``) applied
      to the pairwise sum of encoder positions, with dilation 1 for the
      first stage followed by three cycles of 2, 4, 8, 16, 32, 64;
    * a 1x1 convolutional head (``final``) producing one channel, and an
      optional sigmoid head (``final_1d``) on the encoder output.
    """

    # Dilation (and equal padding) of each stage of the 2D tower.
    _DILATIONS_2D = (1,) + (2, 4, 8, 16, 32, 64) * 3

    def __init__(self, num_1d=None):
        """
        Orca 1Mb model. The trained model weights can be
        loaded into Encoder and Decoder_1m modules.

        Parameters
        ----------
        num_1d : int or None, optional
            The number of 1D targets used for the auxiliary
            task of predicting ChIP-seq profiles. The auxiliary head is
            created whenever this is not None, but ``forward`` only
            evaluates and returns it when the value is truthy.
        """
        super(Net, self).__init__()

        # 1D encoder: (attribute suffix, in channels, out channels, pool).
        # Modules are registered in the order lconvK, convK so that the
        # state_dict layout and RNG consumption order stay fixed.
        encoder_plan = (
            (1, 4, 64, None),
            (2, 64, 96, 4),
            (3, 96, 128, 4),
            (4, 128, 128, 5),
            (5, 128, 128, 5),
            (6, 128, 128, 5),
            (7, 128, 128, 2),
        )
        for stage, c_in, c_out, pool in encoder_plan:
            setattr(self, "lconv%d" % stage, self._stack_1d(c_in, c_out, pool=pool))
            setattr(self, "conv%d" % stage, self._stack_1d(c_out, c_out, relu=True))

        # 2D tower. Only the very first stage sees 128 input channels and
        # is preceded by dropout; every other stage maps 64 -> 32 -> 64.
        self.lconvtwos = nn.ModuleList(
            [
                self._stack_2d(
                    128 if idx == 0 else 64,
                    dilation,
                    dropout=0.1 if idx == 0 else None,
                )
                for idx, dilation in enumerate(self._DILATIONS_2D)
            ]
        )
        self.convtwos = nn.ModuleList(
            [
                self._stack_2d(64, dilation, relu=True)
                for dilation in self._DILATIONS_2D
            ]
        )

        self.final = nn.Sequential(
            nn.Conv2d(64, 5, kernel_size=(1, 1), padding=0),
            nn.BatchNorm2d(5),
            nn.ReLU(inplace=True),
            nn.Conv2d(5, 1, kernel_size=(1, 1), padding=0),
        )
        if num_1d is not None:
            self.final_1d = nn.Sequential(
                nn.Conv1d(128, 128, kernel_size=1, padding=0),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                nn.Conv1d(128, num_1d, kernel_size=1, padding=0),
                nn.Sigmoid(),
            )
        self.num_1d = num_1d

    @staticmethod
    def _stack_1d(c_in, c_out, pool=None, relu=False):
        """Build [MaxPool1d] + 2 x (Conv1d, BatchNorm1d[, ReLU]) as a Sequential."""
        layers = []
        if pool is not None:
            layers.append(nn.MaxPool1d(kernel_size=pool, stride=pool))
        for src in (c_in, c_out):
            layers.append(nn.Conv1d(src, c_out, kernel_size=9, padding=4))
            layers.append(nn.BatchNorm1d(c_out))
            if relu:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _stack_2d(c_in, dilation, relu=False, dropout=None):
        """Build [Dropout] + 2 x (Conv2d, BatchNorm2d[, ReLU]) mapping c_in -> 32 -> 64."""
        layers = []
        if dropout is not None:
            layers.append(nn.Dropout(p=dropout))
        for src, dst in ((c_in, 32), (32, 64)):
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            layers.append(
                nn.Conv2d(
                    src, dst, kernel_size=(3, 3), padding=dilation, dilation=dilation
                )
            )
            layers.append(nn.BatchNorm2d(dst))
            if relu:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward propagation of a batch.

        The computation is split into four segments that are each wrapped
        in ``torch.utils.checkpoint.checkpoint``: the 1D encoder, and the
        2D tower stages 0-6, 7-12 and 13-18 (the last one also applies the
        output head and symmetrisation).

        Parameters
        ----------
        x : torch.Tensor
            Batch of shape ``(N, 4, L)`` (the first convolution expects 4
            input channels). The encoder reduces ``L`` to ``L // 4000``.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            The 2D prediction of shape ``(N, 1, L // 4000, L // 4000)``,
            symmetric in its last two dimensions. If ``self.num_1d`` is
            truthy, a tuple of that tensor and the auxiliary 1D output of
            shape ``(N, num_1d, L // 4000)`` is returned instead.
        """
        later_stages = (
            (self.lconv2, self.conv2),
            (self.lconv3, self.conv3),
            (self.lconv4, self.conv4),
            (self.lconv5, self.conv5),
            (self.lconv6, self.conv6),
            (self.lconv7, self.conv7),
        )

        def run0(x, dummy):
            # `dummy` is intentionally unused; see the note where it is built.
            lout = self.lconv1(x)
            out = self.conv1(lout)
            for lconv, conv in later_stages:
                # The residual sum feeds the next stage; the last stage's
                # ReLU branch output is used on its own.
                lout = lconv(out + lout)
                out = conv(lout)
            # Pairwise sum of position encodings: (N, C, L', L').
            mat = out[:, :, :, None] + out[:, :, None, :]
            if self.num_1d:
                return mat, self.final_1d(out)
            return mat

        # NOTE(review): presumably the extra input with requires_grad=True
        # is there so the checkpointed encoder takes part in autograd even
        # when `x` itself does not require gradients. Its (uninitialised)
        # value is never read.
        dummy = torch.Tensor(1)
        dummy.requires_grad = True
        if self.num_1d:
            cur, output1d = checkpoint(run0, x, dummy)
        else:
            cur = checkpoint(run0, x, dummy)

        def tower(cur, start, stop):
            # Run 2D stages [start, stop) with residual connections.
            for idx in range(start, stop):
                lin = self.lconvtwos[idx](cur)
                # Stage 0 changes the channel count (128 -> 64), so it has
                # no skip connection around its linear part.
                cur = lin if idx == 0 else lin + cur
                cur = self.convtwos[idx](cur) + cur
            return cur

        def run1(cur):
            return tower(cur, 0, 7)

        def run2(cur):
            return tower(cur, 7, 13)

        def run3(cur):
            cur = tower(cur, 13, len(self.lconvtwos))
            cur = self.final(cur)
            # Average with the transpose to make the map symmetric.
            cur = 0.5 * cur + 0.5 * cur.transpose(2, 3)
            return cur

        for segment in (run1, run2, run3):
            cur = checkpoint(segment, cur)

        if self.num_1d:
            return cur, output1d
        return cur