温州做美食网站十大搜索引擎排行榜
#这两行导入了PyTorch和NumPy库,分别用于深度学习和数值计算
import torch
import numpy as np
#这两行导入了PyTorch的神经网络模块和函数模块。
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List, Tuple
import math
from functools import partial
from torch import nn, einsum, diagonal
from math import log2, ceil
import pdb
from sympy import Poly, legendre, Symbol, chebyshevt
from scipy.special import eval_legendre
#这个函数计算Legendre多项式的导数。
def legendreDer(k, x):
    """Evaluate the derivative of the Legendre polynomial P_k at ``x``.

    Uses the identity ``P_k'(x) = sum_{i = k-1, k-3, ..., >= 0} (2i + 1) P_i(x)``.

    Args:
        k: Degree of the Legendre polynomial.
        x: Scalar or numpy array of evaluation points.

    Returns:
        ``P_k'(x)`` with the shape of ``x``. When ``k <= 0`` the sum is empty
        and the integer ``0`` is returned.
    """
    # Degrees k-1, k-3, ... down to 1 or 0 (same parity as k - 1).
    degrees = np.arange(k - 1, -1, -2)
    return sum(((2 * i + 1) * eval_legendre(i, x) for i in degrees), 0)
#这个函数定义了多项式基函数,并应用上下限进行裁剪
def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate a polynomial restricted to the closed interval ``[lb, ub]``.

    Args:
        phi_c: Polynomial coefficients in increasing-degree order, as accepted
            by ``np.polynomial.polynomial.Polynomial``.
        x: Scalar or numpy array of evaluation points.
        lb: Inclusive lower bound of the support.
        ub: Inclusive upper bound of the support.

    Returns:
        The polynomial evaluated at ``x``, multiplied by 0 wherever
        ``x < lb`` or ``x > ub``.
    """
    # 1.0 off the support, 0.0 on it.
    off_support = np.logical_or(x < lb, x > ub) * 1.0
    keep = 1 - off_support
    values = np.polynomial.polynomial.Polynomial(phi_c)(x)
    return values * keep
#这个函数计算给定基base(legendre或chebyshev)下的多项式尺度函数phi,以及分别对应区间[0,0.5]和[0.5,1]的小波函数分段psi1、psi2。
def _half_interval_integral(poly_a, poly_b, lower_half=True):
    """Exactly integrate the product of two polynomials over [0, 1/2] or [1/2, 1].

    ``poly_a`` / ``poly_b`` hold coefficients in increasing-degree order. The
    product's coefficients (their convolution) are thresholded at 1e-8, then
    each monomial is integrated in closed form:
    ``int_0^{1/2} x^n dx = 0.5^(n+1) / (n+1)`` and
    ``int_{1/2}^1 x^n dx = (1 - 0.5^(n+1)) / (n+1)``.
    """
    prod_ = np.convolve(poly_a, poly_b)
    prod_[np.abs(prod_) < 1e-8] = 0
    n = np.arange(len(prod_))
    half_pow = np.power(0.5, 1 + n)
    weight = half_pow if lower_half else (1 - half_pow)
    return (prod_ * 1 / (n + 1) * weight).sum()


def _legendre_phi_psi(k):
    """Legendre scaling functions and wavelets as lists of ``np.poly1d``."""
    x = Symbol('x')
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    for ki in range(k):
        # phi_ki(x) = sqrt(2 ki + 1) * P_ki(2x - 1) (coefficients low -> high degree).
        coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
        phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
        # sqrt(2) * phi_ki(2x), i.e. P_ki(4x - 1) rescaled.
        coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
        phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))

    # psi1 holds the wavelet's piece on [0, 1/2], psi2 its piece on [1/2, 1].
    psi1_coeff = np.zeros((k, k))
    psi2_coeff = np.zeros((k, k))
    for ki in range(k):
        psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
        # Gram-Schmidt: remove the projection of sqrt(2) phi_ki(2x) (which
        # lives on [0, 1/2]) onto every scaling function ...
        for i in range(k):
            proj_ = _half_interval_integral(phi_2x_coeff[ki, :ki + 1], phi_coeff[i, :i + 1])
            psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
        # ... and onto every previously built wavelet.
        for j in range(ki):
            proj_ = _half_interval_integral(phi_2x_coeff[ki, :ki + 1], psi1_coeff[j, :])
            psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
            psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

        # Normalise over [0, 1] = [0, 1/2] (psi1 piece) + [1/2, 1] (psi2 piece).
        norm1 = _half_interval_integral(psi1_coeff[ki, :], psi1_coeff[ki, :])
        norm2 = _half_interval_integral(psi2_coeff[ki, :], psi2_coeff[ki, :], lower_half=False)
        norm_ = np.sqrt(norm1 + norm2)
        psi1_coeff[ki, :] /= norm_
        psi2_coeff[ki, :] /= norm_
        # Thresholding inside the loop: later rows are orthogonalised against
        # the already-cleaned earlier rows.
        psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
        psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

    phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
    psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
    psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]
    return phi, psi1, psi2


def _chebyshev_phi_psi(k):
    """Chebyshev scaling functions and wavelets as support-masked ``phi_`` partials."""
    x = Symbol('x')
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    for ki in range(k):
        if ki == 0:
            phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi)
            phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
        else:
            coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, :ki + 1] = np.flip(
                np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

    phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

    # Quadrature nodes: the 2k roots of T_{2k}(2x - 1), with equal weights.
    # T_{2k} has even degree, so x = 0.5 is never a node and no node sits on
    # the boundary between the two wavelet halves.
    kUse = 2 * k
    roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
    x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
    wm = np.pi / kUse / 2

    psi1_coeff = np.zeros((k, k))
    psi2_coeff = np.zeros((k, k))
    psi1 = [[] for _ in range(k)]
    psi2 = [[] for _ in range(k)]
    for ki in range(k):
        psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
        # Gram-Schmidt of sqrt(2) phi_ki(2x) against scaling functions ...
        for i in range(k):
            proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
            psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
        # ... and against previously built wavelets.
        for j in range(ki):
            proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
            psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
            psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

        psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
        psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

        norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
        norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()
        norm_ = np.sqrt(norm1 + norm2)
        psi1_coeff[ki, :] /= norm_
        psi2_coeff[ki, :] /= norm_
        psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
        psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        # Final pieces; the 1e-16 nudge assigns x = 0.5 to psi1 only.
        psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
        psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)
    return phi, psi1, psi2


def get_phi_psi(k, base):
    """Build the scaling functions and wavelets of an order-``k`` multiwavelet basis on [0, 1].

    The wavelets are obtained by Gram-Schmidt orthogonalisation of
    ``sqrt(2) * phi_i(2x)`` against all scaling functions and the previously
    built wavelets, followed by normalisation over [0, 1].

    Args:
        k: Number of scaling functions / wavelets (polynomials of degree < k).
        base: ``'legendre'`` or ``'chebyshev'``.

    Returns:
        ``(phi, psi1, psi2)``, three lists of ``k`` callables. ``phi[i]`` is
        the i-th scaling function; ``psi1[i]`` / ``psi2[i]`` are the pieces of
        the i-th wavelet intended for [0, 1/2] and [1/2, 1]. For
        ``'legendre'`` they are ``np.poly1d`` objects that are NOT masked to
        their interval (callers must mask). For ``'chebyshev'`` they are
        ``phi_`` partials that evaluate to 0 outside their interval.

    Raises:
        ValueError: If ``base`` is not supported (previously this surfaced as
            an UnboundLocalError at the return statement).
    """
    if base == 'legendre':
        return _legendre_phi_psi(k)
    if base == 'chebyshev':
        return _chebyshev_phi_psi(k)
    raise ValueError('Base not supported')
def get_filter(base, k):
    """Compute the two-scale filter matrices of the order-``k`` multiwavelet basis.

    Entries are quadrature sums over nodes ``x_m`` in [0, 1] with weights ``wm``:

    * ``H0[i, j] = 1/sqrt(2) * sum(wm * phi_i(x/2) * phi_j(x))``
    * ``H1[i, j] = 1/sqrt(2) * sum(wm * phi_i((x+1)/2) * phi_j(x))``
    * ``G0`` / ``G1``: the same with the (piecewise) wavelet ``psi_i`` in place
      of ``phi_i``.

    For ``'legendre'`` the nodes are the k roots of ``P_k(2x - 1)`` with
    Gauss-Legendre weights, and ``PHI0 = PHI1 = I``. For ``'chebyshev'`` the
    nodes are the 2k roots of ``T_{2k}(2x - 1)`` with equal weights
    ``pi / (4k)``. In that case ``PHI0`` / ``PHI1`` are 2x the quadrature Gram
    matrices of ``phi_i(2x)`` and ``phi_i(2x - 1)``. Entries with magnitude
    below 1e-8 are set to 0.

    Args:
        base: ``'legendre'`` or ``'chebyshev'``.
        k: Number of basis functions.

    Returns:
        Tuple ``(H0, H1, G0, G1, PHI0, PHI1)`` of ``(k, k)`` numpy arrays.

    Raises:
        Exception: If ``base`` is not supported.
    """

    def psi(psi1, psi2, i, inp):
        # Piecewise wavelet: psi1 on [0, 0.5], psi2 on (0.5, 1].
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ['legendre', 'chebyshev']:
        raise Exception('Base not supported')

    x = Symbol('x')
    phi, psi1, psi2 = get_phi_psi(k, base)

    # Quadrature rule (nodes x_m, weights wm) for the chosen base.
    if base == 'legendre':
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)
    else:
        # T_{2k} has even degree, so x = 0.5 is never a node.
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = np.pi / kUse / 2

    # Evaluate every basis function once per set of points it is needed on.
    phi_at_x = [phi[idx](x_m) for idx in range(k)]
    phi_lo = [phi[idx](x_m / 2) for idx in range(k)]
    phi_hi = [phi[idx]((x_m + 1) / 2) for idx in range(k)]
    psi_lo = [psi(psi1, psi2, idx, x_m / 2) for idx in range(k)]
    psi_hi = [psi(psi1, psi2, idx, (x_m + 1) / 2) for idx in range(k)]

    inv_sqrt2 = 1 / np.sqrt(2)
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    for row in range(k):
        for col in range(k):
            target = phi_at_x[col]
            H0[row, col] = inv_sqrt2 * (wm * phi_lo[row] * target).sum()
            G0[row, col] = inv_sqrt2 * (wm * psi_lo[row] * target).sum()
            H1[row, col] = inv_sqrt2 * (wm * phi_hi[row] * target).sum()
            G1[row, col] = inv_sqrt2 * (wm * psi_hi[row] * target).sum()

    if base == 'legendre':
        # Legendre scaling functions are orthonormal on each half already.
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)
    else:
        phi_even = [phi[idx](2 * x_m) for idx in range(k)]
        phi_odd = [phi[idx](2 * x_m - 1) for idx in range(k)]
        PHI0 = np.zeros((k, k))
        PHI1 = np.zeros((k, k))
        for row in range(k):
            for col in range(k):
                PHI0[row, col] = (wm * phi_even[row] * phi_even[col]).sum() * 2
                PHI1[row, col] = (wm * phi_odd[row] * phi_odd[col]).sum() * 2
        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    for mat in (H0, H1, G0, G1):
        mat[np.abs(mat) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
#定义了一个基于多小波变换的神经网络模块。这个类执行多小波变换,用于深度学习模型中的特征变换。
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.

    ``values`` (zero-padded or truncated along time to the query length) is
    projected to ``c * k`` channels. It then passes through ``nCZ`` stacked
    ``MWT_CZ1d`` layers (ReLU between them) and is projected back to ``ich``
    channels. ``queries`` only fixes the output length. ``keys``,
    ``attn_mask`` and ``attention_dropout`` are accepted for interface
    compatibility but unused.
    """

    def __init__(self, ich=1, k=8, alpha=16, c=128,
                 nCZ=1, L=0, base='legendre', attention_dropout=0.1):
        super(MultiWaveletTransform, self).__init__()
        print('base', base)
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))

    def forward(self, queries, keys, values, attn_mask):
        """Apply the block.

        Args:
            queries: Tensor ``(B, L, H, E)``; only ``B`` and ``L`` are used.
            keys: Unused.
            values: Tensor ``(B, S, H, D)`` with ``H * D == ich``.
            attn_mask: Unused.

        Returns:
            ``(out, None)`` where ``out`` has shape ``(B, L, H, D)``.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Zero-pad along time up to the query length. The padding must
            # match values' own trailing dims (H, D) and dtype -- building it
            # from queries (H, E) broke whenever E != D.
            zeros = values.new_zeros((B, L - S) + tuple(values.shape[2:]))
            values = torch.cat([values, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
        values = values.reshape(B, L, -1)

        V = self.Lk0(values).view(B, L, self.c, -1)
        for i in range(self.nCZ):
            V = self.MWT_CZ[i](V)
            if i < self.nCZ - 1:
                V = F.relu(V)

        V = self.Lk1(V.reshape(B, L, -1))
        V = V.view(B, L, -1, D)
        return (V.contiguous(), None)
#定义了一个基于多小波变换的交叉注意力模块。这个类结合了交叉注意力和多小波变换,用于捕捉序列间的关系。
class MultiWaveletCross(nn.Module):
    """
    1D Multiwavelet Cross Attention layer.

    Queries, keys and values are projected to ``(B, N, c, k)`` and decomposed
    with the multiwavelet transform (``ec_s`` / ``ec_d`` filters). They are
    combined with ``FourierCrossAttentionW`` modules at every level and then
    reconstructed back to length ``N`` (``rc_e`` / ``rc_o`` filters).
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64,
                 k=8, ich=512,
                 L=0,
                 base='legendre',
                 mode_select_method='random',
                 initializer=None, activation='tanh',
                 **kwargs):
        super(MultiWaveletCross, self).__init__()
        print('base', base)

        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        # Reconstruction filters, corrected by the Gram matrices PHI0 / PHI1.
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.T0 = nn.Linear(k, k)
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))

        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(self, q, k, v, mask=None):
        """Multiwavelet cross attention.

        Args:
            q: Tensor ``(B, N, H, E)`` with ``H * E == ich``.
            k: Tensor ``(B, S, H, E)``.
            v: Tensor ``(B, S, H, E)``.
            mask: Forwarded to the Fourier attention modules.

        Returns:
            ``(out, None)`` with ``out`` of shape ``(B, N, ich)``.
        """
        B, N, H, E = q.shape  # (B, N, H, E) torch.Size([3, 768, 8, 2])
        _, S, _, _ = k.shape  # (B, S, H, E) torch.Size([3, 96, 8, 2])

        q = q.reshape(q.shape[0], q.shape[1], -1)
        k = k.reshape(k.shape[0], k.shape[1], -1)
        v = v.reshape(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        if N > S:
            # Zero-pad keys/values along time to the query length, keeping
            # the projected dtype (no forced float32).
            zeros = torch.zeros_like(q[:, :(N - S), :])
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        # Pad the time axis to the next power of two by repeating the first
        # (nl - N) steps.
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_q = q[:, 0:nl - N, :, :]
        extra_k = k[:, 0:nl - N, :, :]
        extra_v = v[:, 0:nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        Ud_q = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_k = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_v = torch.jit.annotate(List[Tuple[Tensor]], [])

        Us_q = torch.jit.annotate(List[Tensor], [])
        Us_k = torch.jit.annotate(List[Tensor], [])
        Us_v = torch.jit.annotate(List[Tensor], [])

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose
        for i in range(ns - self.L):
            d, q = self.wavelet_transform(q)
            Ud_q += [tuple([d, q])]
            Us_q += [d]
        for i in range(ns - self.L):
            d, k = self.wavelet_transform(k)
            Ud_k += [tuple([d, k])]
            Us_k += [d]
        for i in range(ns - self.L):
            d, v = self.wavelet_transform(v)
            Ud_v += [tuple([d, v])]
            Us_v += [d]
        for i in range(ns - self.L):
            dk, sk = Ud_k[i], Us_k[i]
            dq, sq = Ud_q[i], Us_q[i]
            dv, sv = Ud_v[i], Us_v[i]
            Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]]
            Us += [self.attn3(sq, sk, sv, mask)[0]]
        v = self.attn4(q, k, v, mask)[0]

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return (v.contiguous(), None)

    def wavelet_transform(self, x):
        """One analysis step: ``(B, N, c, k)`` -> detail and smooth parts ``(B, N/2, c, k)``."""
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """One synthesis step: ``(B, N, c, 2k)`` -> ``(B, 2N, c, k)`` by interleaving."""
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # Allocate with x's dtype/device: a plain torch.zeros(...) is always
        # float32 and silently downcast float64 / bfloat16 activations.
        out = x.new_zeros((B, N * 2, c, self.k))
        out[..., ::2, :, :] = x_e
        out[..., 1::2, :, :] = x_o
        return out
#定义了一个基于傅里叶变换的交叉注意力模块。这个类执行交叉注意力操作,在傅里叶域内增强序列间的特征表示。
class FourierCrossAttentionW(nn.Module):
    """Cross attention computed on the lowest Fourier modes along the time axis.

    Inputs are ``(B, L, E, H)``. The first ``min(L // 2, modes)`` rFFT modes
    of the queries are scored against the first ``min(L_v // 2, modes)`` key
    modes. The scores go through ``tanh`` (applied to real and imaginary parts
    separately) or ``softmax`` (applied to the magnitude), are used to mix the
    key modes, and are transformed back with ``irfft``. The result is scaled
    by ``1 / (in_channels * out_channels)``.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                 mode_select_method='random'):
        super(FourierCrossAttentionW, self).__init__()
        print('corss fourier correlation used!')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        """``torch.einsum(order, x, weights)`` with explicit real/imag expansion when either operand is complex."""
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        """Return ``(out, None)`` with ``out`` of the same shape ``(B, L, E, H)`` as ``q``; ``mask`` is unused."""
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        # NOTE(review): the key-mode count comes from the *value* length but
        # is applied to the key spectrum; the two are equal in MultiWaveletCross.
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))
        n_q = len(self.index_q)
        n_kv = len(self.index_k_v)

        # Compute Fourier coefficients.
        xq_ft = torch.fft.rfft(xq, dim=-1)
        xk_ft = torch.fft.rfft(xk, dim=-1)
        # At least complex64, complex128 for float64 inputs. A hard-coded
        # cfloat silently downcast double-precision inputs.
        cdtype = torch.promote_types(torch.promote_types(xq_ft.dtype, xk_ft.dtype), torch.cfloat)
        # The selected modes are the contiguous block 0..n-1, so slicing is
        # equivalent to copying them one index at a time.
        xq_ft_ = xq_ft[..., :n_q].to(cdtype)
        xk_ft_ = xk_ft[..., :n_kv].to(cdtype)

        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} actiation function is not implemented'.format(self.activation))
        # NOTE(review): the key spectrum xk_ft_ (not a value spectrum) is
        # mixed here, so `v` never enters the output; this matches the
        # upstream implementation -- confirm intent before changing.
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=cdtype)
        out_ft[..., :n_q] = xqkv_ft

        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
#定义了一个执行稀疏核傅里叶变换的神经网络模块。这个类使用参数化的稀疏核在傅里叶域内执行特征变换。
class sparseKernelFT1d(nn.Module):
    """Learned channel-mixing operator on the lowest Fourier modes along time.

    Input ``(B, N, c, k)`` is flattened to ``c * k`` channels and transformed
    with ``rfft`` along ``N``. The first ``min(alpha, N // 2 + 1)`` modes are
    mixed across channels by the complex weight ``weights1 + 1j * weights2``.
    The remaining modes are zeroed and the result is transformed back to
    ``(B, N, c, k)``.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        super(sparseKernelFT1d, self).__init__()

        self.modes1 = alpha
        self.scale = (1 / (c * k * c * k))
        # nn.Parameter already sets requires_grad=True.
        self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.k = k

    def compl_mul1d(self, order, x, weights):
        """``torch.einsum(order, x, weights)`` with explicit real/imag expansion when either operand is complex."""
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, x):
        """Apply the operator to ``x`` of shape ``(B, N, c, k)``; returns the same shape."""
        B, N, c, k = x.shape  # (B, N, c, k)

        x = x.reshape(B, N, -1).permute(0, 2, 1)  # (B, c*k, N)
        x_fft = torch.fft.rfft(x)
        # Multiply relevant Fourier modes
        n_modes = min(self.modes1, N // 2 + 1)
        # At least complex64; complex128 for float64 modules/inputs instead of
        # silently downcasting to cfloat.
        cdtype = torch.promote_types(x_fft.dtype, torch.cfloat)
        out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=cdtype)
        out_ft[:, :, :n_modes] = self.compl_mul1d("bix,iox->box", x_fft[:, :, :n_modes],
                                                  torch.complex(self.weights1, self.weights2)[:, :, :n_modes])
        x = torch.fft.irfft(out_ft, n=N)
        return x.permute(0, 2, 1).reshape(B, N, c, k)
# ##
#定义了一个执行多小波变换卷积操作的神经网络模块。这个类结合了多小波变换和卷积操作,用于深度学习模型中的特征提取。
class MWT_CZ1d(nn.Module):
    """One multiwavelet layer: decompose, apply learned Fourier kernels per level, reconstruct.

    ``A``, ``B`` and ``C`` are ``sparseKernelFT1d`` operators applied at every
    decomposition level. ``T0`` acts on the coarsest smooth coefficients.
    ``ec_s`` / ``ec_d`` are the analysis filters and ``rc_e`` / ``rc_o`` the
    synthesis filters built from ``get_filter``.
    """

    def __init__(self,
                 k=3, alpha=64,
                 L=0, c=1,
                 base='legendre',
                 initializer=None,
                 **kwargs):
        super(MWT_CZ1d, self).__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        # Reconstruction filters, corrected by the Gram matrices PHI0 / PHI1.
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))

        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(B, N, c, k)``; returns the same shape."""
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Pad the time axis to the next power of two by repeating the first
        # (nl - N) steps.
        extra_x = x[:, 0:nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        # decompose
        for i in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]

        return x

    def wavelet_transform(self, x):
        """One analysis step: ``(B, N, c, k)`` -> detail and smooth parts ``(B, N/2, c, k)``."""
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """One synthesis step: ``(B, N, c, 2k)`` -> ``(B, 2N, c, k)`` by interleaving."""
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        # Allocate with x's dtype/device: a plain torch.zeros(...) is always
        # float32 and silently downcast float64 / bfloat16 activations.
        out = x.new_zeros((B, N * 2, c, self.k))
        out[..., ::2, :, :] = x_e
        out[..., 1::2, :, :] = x_o
        return out