Source code for orca_models

"""
This module contains the class definition of all Orca models.
For usage of the models, see the orca_predict module.
"""
import pathlib
import numpy as np

import torch
from torch import nn

from orca_modules import Encoder, Encoder2, Encoder3, Decoder, Decoder_1m, Net

ORCA_PATH = str(pathlib.Path(__file__).parent.absolute())


class H1esc(nn.Module):
    """
    Orca H1-ESC model (1-32Mb)

    Attributes
    ----------
    net0 : nn.DataParallel(Encoder)
        The first section of the multi-resolution encoder (bp resolution to 4kb resolution).
    net : nn.DataParallel(Encoder2)
        The second section of the multi-resolution encoder (4kb resolution to 128kb resolution).
    denets : dict(int: nn.DataParallel(Decoder))
        Decoders at each level, stored in a dictionary with an integer as key.
    normmats : dict(int: numpy.ndarray)
        The distance-based background matrices with expected log fold over background
        values at each level.
    epss : dict(int: float)
        The minimum background value at each level. Used for stabilizing the log fold
        computation by adding to both the numerator and the denominator.
    """

    def __init__(self):
        super(H1esc, self).__init__()
        modelstr = "h1esc"

        self.net = nn.DataParallel(Encoder2())
        self.denet_1 = nn.DataParallel(Decoder())
        self.denet_2 = nn.DataParallel(Decoder())
        self.denet_4 = nn.DataParallel(Decoder())
        self.denet_8 = nn.DataParallel(Decoder())
        self.denet_16 = nn.DataParallel(Decoder())
        self.denet_32 = nn.DataParallel(Decoder())

        num_threads = torch.get_num_threads()
        self.net.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".net.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_1.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d1.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_2.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d2.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_4.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d4.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_8.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d8.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_16.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d16.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )
        self.denet_32.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d32.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=True,
        )

        self.net0 = nn.DataParallel(Encoder())
        net0_dict = self.net0.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_" + modelstr + ".net0.statedict",
            map_location=torch.device("cpu"),
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net0_dict}
        self.net0.load_state_dict(pretrained_dict_filtered)

        self.denet_1_pt = nn.DataParallel(Decoder_1m())
        denet_1_pt_dict = self.denet_1_pt.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_" + modelstr + ".net0.statedict",
            map_location=torch.device("cpu"),
        )
        pretrained_dict_filtered = {
            key: pretrained_dict["module." + key] for key in denet_1_pt_dict
        }
        self.denet_1_pt.load_state_dict(pretrained_dict_filtered)
        self.denet_1_pt.eval()

        self.net0.eval()
        self.net.eval()
        self.denet_1.eval()
        self.denet_2.eval()
        self.denet_4.eval()
        self.denet_8.eval()
        self.denet_16.eval()
        self.denet_32.eval()

        expected_log = np.load(
            ORCA_PATH + "/resources/4DNFI9GMP2J8.rebinned.mcool.expected.res4000.npy"
        )
        normmat = np.exp(expected_log[np.abs(np.arange(8000)[None, :] - np.arange(8000)[:, None])])

        normmat_r1 = np.reshape(normmat[:250, :250], (250, 1, 250, 1)).mean(axis=1).mean(axis=2)
        normmat_r2 = np.reshape(normmat[:500, :500], (250, 2, 250, 2)).mean(axis=1).mean(axis=2)
        normmat_r4 = np.reshape(normmat[:1000, :1000], (250, 4, 250, 4)).mean(axis=1).mean(axis=2)
        normmat_r8 = np.reshape(normmat[:2000, :2000], (250, 8, 250, 8)).mean(axis=1).mean(axis=2)
        normmat_r16 = np.reshape(normmat[:4000, :4000], (250, 16, 250, 16)).mean(axis=1).mean(axis=2)
        normmat_r32 = np.reshape(normmat[:8000, :8000], (250, 32, 250, 32)).mean(axis=1).mean(axis=2)

        eps1 = np.min(normmat_r1)
        eps2 = np.min(normmat_r2)
        eps4 = np.min(normmat_r4)
        eps8 = np.min(normmat_r8)
        eps16 = np.min(normmat_r16)
        eps32 = np.min(normmat_r32)

        self.normmats = {
            1: normmat_r1,
            2: normmat_r2,
            4: normmat_r4,
            8: normmat_r8,
            16: normmat_r16,
            32: normmat_r32,
        }
        self.epss = {1: eps1, 2: eps2, 4: eps4, 8: eps8, 16: eps16, 32: eps32}
        self.denets = {
            1: self.denet_1,
            2: self.denet_2,
            4: self.denet_4,
            8: self.denet_8,
            16: self.denet_16,
            32: self.denet_32,
        }
        torch.set_num_threads(num_threads)
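
# --- Illustrative example (added for documentation; not part of the original module) ---
# A minimal sketch of instantiating the H1esc model on CPU and inspecting the
# per-level decoders and distance-based backgrounds described in the docstring.
# It assumes the pretrained statedict and resource files are present under
# ORCA_PATH; the full multiscale prediction workflow lives in the orca_predict
# module.
#
#   >>> h1 = H1esc()
#   >>> sorted(h1.denets)
#   [1, 2, 4, 8, 16, 32]
#   >>> h1.normmats[1].shape          # 1Mb window, 4kb bins
#   (250, 250)
#   >>> h1.normmats[32].shape         # 32Mb window, 128kb bins
#   (250, 250)
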
class Hff(nn.Module):
    """
    Orca HFF model (1-32Mb)

    Attributes
    ----------
    net0 : nn.DataParallel(Encoder)
        The first section of the multi-resolution encoder (bp resolution to 4kb resolution).
    net : nn.DataParallel(Encoder2)
        The second section of the multi-resolution encoder (4kb resolution to 128kb resolution).
    denets : dict(int: nn.DataParallel(Decoder))
        Decoders at each level, stored in a dictionary with an integer as key.
    normmats : dict(int: numpy.ndarray)
        The distance-based background matrices with expected log fold over background
        values at each level.
    epss : dict(int: float)
        The minimum background value at each level. Used for stabilizing the log fold
        computation by adding to both the numerator and the denominator.
    """

    def __init__(self):
        super(Hff, self).__init__()
        modelstr = "hff"

        self.net = nn.DataParallel(Encoder2())
        self.denet_1 = nn.DataParallel(Decoder())
        self.denet_2 = nn.DataParallel(Decoder())
        self.denet_4 = nn.DataParallel(Decoder())
        self.denet_8 = nn.DataParallel(Decoder())
        self.denet_16 = nn.DataParallel(Decoder())
        self.denet_32 = nn.DataParallel(Decoder())

        num_threads = torch.get_num_threads()
        self.net.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".net.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_1.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d1.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_2.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d2.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_4.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d4.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_8.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d8.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_16.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d16.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )
        self.denet_32.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d32.statedict",
                map_location=torch.device("cpu"),
            ),
            strict=False,
        )

        self.net0 = nn.DataParallel(Encoder())
        net0_dict = self.net0.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_" + modelstr + ".net0.statedict",
            map_location=torch.device("cpu"),
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net0_dict}
        self.net0.load_state_dict(pretrained_dict_filtered)

        self.denet_1_pt = nn.DataParallel(Decoder_1m())
        denet_1_pt_dict = self.denet_1_pt.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_" + modelstr + ".net0.statedict",
            map_location=torch.device("cpu"),
        )
        pretrained_dict_filtered = {
            key: pretrained_dict["module." + key] for key in denet_1_pt_dict
        }
        self.denet_1_pt.load_state_dict(pretrained_dict_filtered)
        self.denet_1_pt.eval()

        self.net0.eval()
        self.net.eval()
        self.denet_1.eval()
        self.denet_2.eval()
        self.denet_4.eval()
        self.denet_8.eval()
        self.denet_16.eval()
        self.denet_32.eval()

        expected_log = np.load(
            ORCA_PATH + "/resources/4DNFI643OYP9.rebinned.mcool.expected.res4000.npy"
        )
        normmat = np.exp(expected_log[np.abs(np.arange(8000)[:, None] - np.arange(8000)[None, :])])

        normmat_r1 = np.reshape(normmat[:250, :250], (250, 1, 250, 1)).mean(axis=1).mean(axis=2)
        normmat_r2 = np.reshape(normmat[:500, :500], (250, 2, 250, 2)).mean(axis=1).mean(axis=2)
        normmat_r4 = np.reshape(normmat[:1000, :1000], (250, 4, 250, 4)).mean(axis=1).mean(axis=2)
        normmat_r8 = np.reshape(normmat[:2000, :2000], (250, 8, 250, 8)).mean(axis=1).mean(axis=2)
        normmat_r16 = np.reshape(normmat[:4000, :4000], (250, 16, 250, 16)).mean(axis=1).mean(axis=2)
        normmat_r32 = np.reshape(normmat[:8000, :8000], (250, 32, 250, 32)).mean(axis=1).mean(axis=2)

        eps1 = np.min(normmat_r1)
        eps2 = np.min(normmat_r2)
        eps4 = np.min(normmat_r4)
        eps8 = np.min(normmat_r8)
        eps16 = np.min(normmat_r16)
        eps32 = np.min(normmat_r32)

        self.normmats = {
            1: normmat_r1,
            2: normmat_r2,
            4: normmat_r4,
            8: normmat_r8,
            16: normmat_r16,
            32: normmat_r32,
        }
        self.epss = {1: eps1, 2: eps2, 4: eps4, 8: eps8, 16: eps16, 32: eps32}
        self.denets = {
            1: self.denet_1,
            2: self.denet_2,
            4: self.denet_4,
            8: self.denet_8,
            16: self.denet_16,
            32: self.denet_32,
        }
        torch.set_num_threads(num_threads)
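
# --- Illustrative example (added for documentation; not part of the original module) ---
# Per the docstring, `epss` stabilizes the log fold-over-background computation
# by being added to both the numerator and the denominator. Given a hypothetical
# 250 x 250 balanced contact matrix `observed` at level `lvl`, the comparison
# against the stored background would look like:
#
#   >>> hff = Hff()
#   >>> lvl = 8
#   >>> normmat, eps = hff.normmats[lvl], hff.epss[lvl]
#   >>> log_fold = np.log((observed + eps) / (normmat + eps))
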
class HCTnoc(nn.Module):
    """
    Orca cohesin-depleted HCT116 model (1-32Mb)

    Attributes
    ----------
    net0 : nn.DataParallel(Encoder)
        The first section of the multi-resolution encoder (bp resolution to 4kb resolution).
    net : nn.DataParallel(Encoder2)
        The second section of the multi-resolution encoder (4kb resolution to 128kb resolution).
    denets : dict(int: nn.DataParallel(Decoder))
        Decoders at each level, stored in a dictionary with an integer as key.
    normmats : dict(int: numpy.ndarray)
        The distance-based background matrices with expected log fold over background
        values at each level.
    epss : dict(int: float)
        The minimum background value at each level. Used for stabilizing the log fold
        computation by adding to both the numerator and the denominator.
    """

    def __init__(self):
        super(HCTnoc, self).__init__()
        modelstr = "hctnoc"

        self.net = nn.DataParallel(Encoder2())
        self.denet_1 = nn.DataParallel(Decoder())
        self.denet_2 = nn.DataParallel(Decoder())
        self.denet_4 = nn.DataParallel(Decoder())
        self.denet_8 = nn.DataParallel(Decoder())
        self.denet_16 = nn.DataParallel(Decoder())
        self.denet_32 = nn.DataParallel(Decoder())

        self.net.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".net.statedict"), strict=True
        )
        self.denet_1.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d1.statedict"), strict=True
        )
        self.denet_2.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d2.statedict"), strict=True
        )
        self.denet_4.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d4.statedict"), strict=True
        )
        self.denet_8.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d8.statedict"), strict=True
        )
        self.denet_16.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d16.statedict"), strict=True
        )
        self.denet_32.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".d32.statedict"), strict=True
        )

        self.net0 = nn.DataParallel(Encoder())
        self.net0.load_state_dict(
            torch.load(ORCA_PATH + "/models/orca_" + modelstr + ".net0.statedict"), strict=True
        )
        self.net0.cuda()

        self.net0.eval()
        self.net.eval()
        self.denet_1.eval()
        self.denet_2.eval()
        self.denet_4.eval()
        self.denet_8.eval()
        self.denet_16.eval()
        self.denet_32.eval()

        smooth_diag = np.load(
            ORCA_PATH + "/resources/4DNFILP99QJS.HCT_auxin6h.rebinned.mcool.expected.res4000.npy"
        )
        normmat = np.exp(smooth_diag[np.abs(np.arange(8000)[None, :] - np.arange(8000)[:, None])])

        normmat_r1 = np.reshape(normmat[:250, :250], (250, 1, 250, 1)).mean(axis=1).mean(axis=2)
        normmat_r2 = np.reshape(normmat[:500, :500], (250, 2, 250, 2)).mean(axis=1).mean(axis=2)
        normmat_r4 = np.reshape(normmat[:1000, :1000], (250, 4, 250, 4)).mean(axis=1).mean(axis=2)
        normmat_r8 = np.reshape(normmat[:2000, :2000], (250, 8, 250, 8)).mean(axis=1).mean(axis=2)
        normmat_r16 = np.reshape(normmat[:4000, :4000], (250, 16, 250, 16)).mean(axis=1).mean(axis=2)
        normmat_r32 = np.reshape(normmat[:8000, :8000], (250, 32, 250, 32)).mean(axis=1).mean(axis=2)

        eps1 = np.min(normmat_r1)
        eps2 = np.min(normmat_r2)
        eps4 = np.min(normmat_r4)
        eps8 = np.min(normmat_r8)
        eps16 = np.min(normmat_r16)
        eps32 = np.min(normmat_r32)

        self.normmats = {
            1: normmat_r1,
            2: normmat_r2,
            4: normmat_r4,
            8: normmat_r8,
            16: normmat_r16,
            32: normmat_r32,
        }
        self.epss = {1: eps1, 2: eps2, 4: eps4, 8: eps8, 16: eps16, 32: eps32}
        self.denets = {
            1: self.denet_1,
            2: self.denet_2,
            4: self.denet_4,
            8: self.denet_8,
            16: self.denet_16,
            32: self.denet_32,
        }
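
# --- Note (added for documentation; not part of the original module) ---
# Unlike H1esc and Hff, this constructor loads its statedicts without
# `map_location` and moves `net0` onto the GPU, so instantiating HCTnoc
# expects a CUDA-capable device:
#
#   >>> if torch.cuda.is_available():
#   ...     hct = HCTnoc()
#   ...     sorted(hct.denets)
#   [1, 2, 4, 8, 16, 32]
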
class H1esc_1M(nn.Module):
    """
    Orca H1-ESC model (1Mb)

    Attributes
    ----------
    net : nn.DataParallel(Net)
        Integrated Encoder and Decoder for the 1Mb model.
    normmats : dict(int: numpy.ndarray)
        The distance-based background matrices with expected log fold over background
        values at each level.
    epss : dict(int: float)
        The minimum background value at each level. Used for stabilizing the log fold
        computation by adding to both the numerator and the denominator.
    """

    def __init__(self):
        super(H1esc_1M, self).__init__()
        self.net = nn.DataParallel(Net(num_1d=32))

        num_threads = torch.get_num_threads()
        net_dict = self.net.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_h1esc.net0.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net_dict}
        self.net.load_state_dict(pretrained_dict_filtered)
        self.net.eval()

        expected_log = np.load(
            ORCA_PATH + "/resources/4DNFI9GMP2J8.rebinned.mcool.expected.res1000.npy"
        )[:1000]
        normmat = np.exp(expected_log[np.abs(np.arange(1000)[None, :] - np.arange(1000)[:, None])])
        normmat_r = np.reshape(normmat, (250, 4, 250, 4)).mean(axis=1).mean(axis=2)
        eps = np.min(normmat_r)
        self.normmats = {1: normmat_r}
        self.epss = {1: eps}
        torch.set_num_threads(num_threads)

    def forward(self, x):
        pred, _ = self.net.forward(x)
        return pred
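
# --- Illustrative example (added for documentation; not part of the original module) ---
# A sketch of running the integrated 1Mb model on a random input. The
# (batch, 4, 1_000_000) one-hot layout and the 250 x 250 output interpretation
# are assumptions of this sketch; see orca_predict for the exact sequence
# preprocessing used with these models.
#
#   >>> model = H1esc_1M()
#   >>> x = torch.rand(1, 4, 1_000_000)   # placeholder; real input is one-hot DNA
#   >>> with torch.no_grad():
#   ...     pred = model(x)
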
class Hff_1M(nn.Module):
    """
    Orca HFF model (1Mb)

    Attributes
    ----------
    net : nn.DataParallel(Net)
        Integrated Encoder and Decoder for the 1Mb model.
    normmats : dict(int: numpy.ndarray)
        The distance-based background matrices with expected log fold over background
        values at each level.
    epss : dict(int: float)
        The minimum background value at each level. Used for stabilizing the log fold
        computation by adding to both the numerator and the denominator.
    """

    def __init__(self):
        super(Hff_1M, self).__init__()
        self.net = nn.DataParallel(Net(num_1d=22))

        num_threads = torch.get_num_threads()
        net_dict = self.net.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_hff.net0.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net_dict}
        self.net.load_state_dict(pretrained_dict_filtered)
        self.net.eval()

        expected = np.exp(
            np.load(ORCA_PATH + "/resources/4DNFI643OYP9.rebinned.mcool.expected.res1000.npy")[
                :1000
            ]
        )
        normmat = expected[np.abs(np.arange(1000)[:, None] - np.arange(1000)[None, :])]
        normmat_r = np.reshape(normmat, (250, 4, 250, 4)).mean(axis=1).mean(axis=2)
        eps = np.min(normmat_r)
        self.normmats = {1: normmat_r}
        self.epss = {1: eps}
        torch.set_num_threads(num_threads)

    def forward(self, x):
        pred, _ = self.net.forward(x)
        return pred
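
# --- Illustrative example (added for documentation; not part of the original module) ---
# If `pred` holds a predicted log fold over background (as the docstring
# describes), an approximate background-scaled contact matrix can be recovered
# with the stored normmat and eps; `pred` is a hypothetical (250, 250) array here.
#
#   >>> hff_1m = Hff_1M()
#   >>> normmat, eps = hff_1m.normmats[1], hff_1m.epss[1]
#   >>> approx_contacts = np.exp(pred) * (normmat + eps) - eps
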
class H1esc_256M(nn.Module):
    """
    Orca H1-ESC model (32-256Mb)

    Attributes
    ----------
    net0 : nn.DataParallel(Encoder)
        The first section of the multi-resolution encoder (bp resolution to 4kb resolution).
    net1 : nn.DataParallel(Encoder2)
        The second section of the multi-resolution encoder (4kb resolution to 128kb resolution).
    net : nn.DataParallel(Encoder3)
        The third section of the multi-resolution encoder (128kb resolution to 1024kb resolution).
    denets : dict(int: nn.DataParallel(Decoder))
        Decoders at each level, stored in a dictionary with an integer as key.
    background_cis : numpy.ndarray
        The distance-based expected background values for cis contacts at 32kb resolution,
        padded with NaNs beyond the precomputed range.
    background_trans : numpy.ndarray
        The expected background value for trans (inter-chromosomal) contacts at 32kb
        resolution.
    """

    def __init__(self):
        super(H1esc_256M, self).__init__()
        modelstr = "h1esc_256m"

        self.net = nn.DataParallel(Encoder3())
        self.denet_32 = nn.DataParallel(Decoder())
        self.denet_64 = nn.DataParallel(Decoder())
        self.denet_128 = nn.DataParallel(Decoder())
        self.denet_256 = nn.DataParallel(Decoder())

        num_threads = torch.get_num_threads()
        self.net.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".net.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_32.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d32.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_64.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d64.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_128.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d128.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_256.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d256.statedict",
                map_location=torch.device("cpu"),
            )
        )

        self.net0 = nn.DataParallel(Encoder())
        net0_dict = self.net0.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_h1esc.net0.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net0_dict}
        self.net0.load_state_dict(pretrained_dict_filtered)

        self.net1 = nn.DataParallel(Encoder2())
        net1_dict = self.net1.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_h1esc.net.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict[key] for key in net1_dict}
        self.net1.load_state_dict(pretrained_dict_filtered)

        self.net0.eval()
        self.net1.eval()
        self.net.eval()
        self.denet_32.eval()
        self.denet_64.eval()
        self.denet_128.eval()
        self.denet_256.eval()

        self.background_cis = np.load(
            ORCA_PATH + "/resources/4DNFI9GMP2J8.rebinned.mcool.expected.res32000.mono.npy"
        )
        self.background_trans = np.load(
            ORCA_PATH + "/resources/4DNFI9GMP2J8.rebinned.mcool.expected.res32000.trans.npy"
        )
        self.background_cis = np.hstack([np.exp(self.background_cis), np.repeat(np.nan, 2000)])
        self.background_trans = np.exp(self.background_trans)

        self.denets = {
            32: self.denet_32,
            64: self.denet_64,
            128: self.denet_128,
            256: self.denet_256,
        }
        torch.set_num_threads(num_threads)
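
# --- Illustrative example (added for documentation; not part of the original module) ---
# The 32-256Mb models store 1D cis and trans backgrounds rather than per-level
# matrices; `background_cis` is NaN-padded beyond the precomputed range, which
# can be checked directly:
#
#   >>> model = H1esc_256M()
#   >>> sorted(model.denets)
#   [32, 64, 128, 256]
#   >>> bool(np.isnan(model.background_cis[-2000:]).all())
#   True
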
class Hff_256M(nn.Module):
    """
    Orca HFF model (32-256Mb)

    Attributes
    ----------
    net0 : nn.DataParallel(Encoder)
        The first section of the multi-resolution encoder (bp resolution to 4kb resolution).
    net1 : nn.DataParallel(Encoder2)
        The second section of the multi-resolution encoder (4kb resolution to 128kb resolution).
    net : nn.DataParallel(Encoder3)
        The third section of the multi-resolution encoder (128kb resolution to 1024kb resolution).
    denets : dict(int: nn.DataParallel(Decoder))
        Decoders at each level, stored in a dictionary with an integer as key.
    background_cis : numpy.ndarray
        The distance-based expected background values for cis contacts at 32kb resolution,
        padded with NaNs beyond the precomputed range.
    background_trans : numpy.ndarray
        The expected background value for trans (inter-chromosomal) contacts at 32kb
        resolution.
    """

    def __init__(self):
        super(Hff_256M, self).__init__()
        modelstr = "hff_256m"

        self.net = nn.DataParallel(Encoder3())
        self.denet_32 = nn.DataParallel(Decoder())
        self.denet_64 = nn.DataParallel(Decoder())
        self.denet_128 = nn.DataParallel(Decoder())
        self.denet_256 = nn.DataParallel(Decoder())

        num_threads = torch.get_num_threads()
        self.net.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".net.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_32.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d32.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_64.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d64.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_128.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d128.statedict",
                map_location=torch.device("cpu"),
            )
        )
        self.denet_256.load_state_dict(
            torch.load(
                ORCA_PATH + "/models/orca_" + modelstr + ".d256.statedict",
                map_location=torch.device("cpu"),
            )
        )

        self.net0 = nn.DataParallel(Encoder())
        net0_dict = self.net0.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_hff.net0.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict["module." + key] for key in net0_dict}
        self.net0.load_state_dict(pretrained_dict_filtered)

        self.net1 = nn.DataParallel(Encoder2())
        net1_dict = self.net1.state_dict()
        pretrained_dict = torch.load(
            ORCA_PATH + "/models/orca_hff.net.statedict", map_location=torch.device("cpu")
        )
        pretrained_dict_filtered = {key: pretrained_dict[key] for key in net1_dict}
        self.net1.load_state_dict(pretrained_dict_filtered)

        self.net0.eval()
        self.net1.eval()
        self.net.eval()
        self.denet_32.eval()
        self.denet_64.eval()
        self.denet_128.eval()
        self.denet_256.eval()

        self.background_cis = np.load(
            ORCA_PATH + "/resources/4DNFI643OYP9.rebinned.mcool.expected.res32000.mono.npy"
        )
        self.background_trans = np.load(
            ORCA_PATH + "/resources/4DNFI643OYP9.rebinned.mcool.expected.res32000.trans.npy"
        )
        self.background_cis = np.hstack([np.exp(self.background_cis), np.repeat(np.nan, 2000)])
        self.background_trans = np.exp(self.background_trans)

        self.denets = {
            32: self.denet_32,
            64: self.denet_64,
            128: self.denet_128,
            256: self.denet_256,
        }
        torch.set_num_threads(num_threads)
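
# --- Illustrative example (added for documentation; not part of the original module) ---
# A heavily hedged sketch of building a distance-based cis background for the
# 32Mb level of the 256Mb model by indexing `background_cis` at bin distances.
# A 32Mb window with 250 bins uses 128kb bins, i.e. 4 x 32kb per bin; indexing
# by distance (rather than averaging over 32kb sub-bins, as orca_predict may do)
# is an approximation made only for illustration.
#
#   >>> hff_256m = Hff_256M()
#   >>> dist = np.abs(np.arange(250)[:, None] - np.arange(250)[None, :])
#   >>> approx_bg_32mb = hff_256m.background_cis[dist * 4]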