flow.py
import torch
from torch import nn
from torch.nn import functional as F

# safe_log comes from this repo's utils module; presumably a logarithm that
# guards against log(0) when the Jacobian determinant underflows.
from utils import safe_log
class NormalizingFlow(nn.Module):
    """A composition of planar flows, as in Rezende & Mohamed (2015).

    forward returns the transformed sample z_K together with the
    log-det-Jacobian of each intermediate transformation, which is
    what the flow-based variational bound needs.
    """

    def __init__(self, dim, flow_length):
        super().__init__()
        self.transforms = nn.Sequential(*(
            PlanarFlow(dim) for _ in range(flow_length)
        ))
        self.log_jacobians = nn.Sequential(*(
            PlanarFlowLogDetJacobian(t) for t in self.transforms
        ))

    def forward(self, z):
        log_jacobians = []
        for transform, log_jacobian in zip(self.transforms, self.log_jacobians):
            # Each log-det-Jacobian is evaluated at the input of its
            # transform, so compute it before updating z.
            log_jacobians.append(log_jacobian(z))
            z = transform(z)
        zk = z
        return zk, log_jacobians
class PlanarFlow(nn.Module):
    """A single planar flow: f(z) = z + u * tanh(w^T z + b),

    where w is `weight`, u is `scale`, and b is `bias`. Note that f is
    only guaranteed to be invertible when w^T u >= -1; the small uniform
    initialization below starts the parameters near that regime, but the
    constraint is not enforced during training.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.tanh = nn.Tanh()
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.uniform_(-0.01, 0.01)
        self.scale.data.uniform_(-0.01, 0.01)
        self.bias.data.uniform_(-0.01, 0.01)

    def forward(self, z):
        activation = F.linear(z, self.weight, self.bias)
        return z + self.scale * self.tanh(activation)
class PlanarFlowLogDetJacobian(nn.Module):
    """Computes the log absolute determinant of the Jacobian of the
    planar flow transformation it wraps:

        psi(z) = (1 - tanh^2(w^T z + b)) * w
        log|det J| = log|1 + psi(z)^T u|
    """

    def __init__(self, affine):
        super().__init__()
        self.weight = affine.weight
        self.bias = affine.bias
        self.scale = affine.scale
        self.tanh = affine.tanh

    def forward(self, z):
        activation = F.linear(z, self.weight, self.bias)
        psi = (1 - self.tanh(activation) ** 2) * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())
        return safe_log(det_grad.abs())
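

if __name__ == "__main__":
    import math

    # Minimal usage sketch: push samples from a standard-normal base
    # distribution through the flow and accumulate the log-det-Jacobians
    # to get the log-density of the outputs. The dimension, flow length,
    # and batch size are arbitrary choices for illustration; this assumes
    # utils.safe_log is importable, as at the top of the file.
    flow = NormalizingFlow(dim=2, flow_length=8)
    z0 = torch.randn(64, 2)  # batch of base samples z_0 ~ N(0, I)
    zk, log_jacobians = flow(z0)

    # log q_K(z_K) = log q_0(z_0) - sum_k log|det dF_k/dz_{k-1}|
    log_q0 = (-0.5 * (z0 ** 2).sum(dim=1, keepdim=True)
              - 0.5 * z0.size(1) * math.log(2 * math.pi))
    log_qk = log_q0 - sum(log_jacobians)
    print(zk.shape, log_qk.shape)  # torch.Size([64, 2]) torch.Size([64, 1])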