Skip to content
24 changes: 24 additions & 0 deletions tests/test_alignment_crf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import torch_struct
import pytest


@pytest.mark.skipif(not torch.cuda.is_available(), reason='needs CUDA')
def test_alignment_crf_shapes():
batch, N, M = 2, 4, 5
log_potentials = torch.rand(batch, N, M, 3).cuda()

dist = torch_struct.AlignmentCRF(log_potentials)
assert (batch, N, M, 3) == dist.argmax.shape
assert (batch, N, M, 3) == dist.marginals.shape
assert (batch,) == dist.partition.shape

# Fail due to AttributeError: 'BandedMatrix' object has no attribute
# 'unsqueeze'
assert (batch,) == dist.entropy.shape
# assert (9, batch, N, M, 3) == dist.sample([9]).shape

# Fails due to: RuntimeError: Expected condition, x and y to be on
# the same device, but condition is on cpu and x and y are on
# cuda:0 and cuda:0 respectively
# assert (8, batch,) == dist.topk(8).shape
14 changes: 14 additions & 0 deletions tests/test_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from hypothesis import given
from hypothesis.strategies import integers
import genbmm

bint = integers(min_value=1, max_value=4)
mint = integers(min_value=6, max_value=8)
nint = integers(min_value=3, max_value=5)
kint = integers(min_value=9, max_value=11)


@given(bint, mint, nint, kint)
def test_matmul(batch, m, n, k):
a, b = torch.rand((m, n)), torch.rand((n, k))
3 changes: 2 additions & 1 deletion tests/test_semirings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from hypothesis import given
from hypothesis import given, settings
from hypothesis.strategies import integers


Expand All @@ -17,6 +17,7 @@


@given(lint, lint, lint)
@settings(deadline=None) # Avoid spurious warnings when first run
def test_max(a, b, c):
torch.manual_seed(0)
t1 = torch.rand(a, 1, c).requires_grad_(True)
Expand Down
11 changes: 7 additions & 4 deletions torch_struct/alignment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import torch
from .helpers import _Struct
import math
import warnings

try:
import genbmm

except ImportError:
pass
warnings.warn('Could not import genbmm. '
'However, genbmm is only used for CUDA operations.')

from .semirings import LogSemiring
from .semirings.fast_semirings import broadcast
Expand Down Expand Up @@ -97,9 +100,9 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
# Create finalizing paths.
point = (l + M) // 2

charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
charta[1][:, b, point:, 1, ind, :, :, Mid]
)
init = torch.zeros(charta[1].shape, device=charta[1].device).bool()
init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
charta[1] = semiring.fill(charta[1], init, semiring.one)

for b in range(lengths.shape[0]):
point = (lengths[b] + M) // 2
Expand Down
3 changes: 3 additions & 0 deletions torch_struct/semirings/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def forward(ctx, input, dim):
def backward(ctx, grad_output):

logits, part, dim = ctx.saved_tensors
# Replace infinite logits with max float, otherwise softmax gives NaNs
# Perhaps this could be done earlier (during forward pass)?
logits[logits == float('inf')] = torch.finfo(logits.dtype).max
grad_input = None
if ctx.needs_input_grad[0]:

Expand Down
3 changes: 1 addition & 2 deletions torch_struct/semirings/semirings.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ class LogSemiring(_BaseLog):

Gradients give marginals.
"""

@classmethod
def matmul(cls, a, b):
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
Expand Down Expand Up @@ -192,7 +191,7 @@ def convert(cls, orig_potentials):
dtype=orig_potentials.dtype,
device=orig_potentials.device,
)
potentials = cls.fill(potentials, torch.tensor(True), cls.zero)
potentials = cls.fill(potentials, torch.tensor(True, device=potentials.device), cls.zero.to(potentials.device))
potentials[0] = orig_potentials
return potentials

Expand Down