Skip to content

Commit 9b8e5ee

Browse files
authored
changes per Patrik's comments (#1285)
* changes per Patrik's comments
* update conversion script
1 parent becc803 commit 9b8e5ee

File tree

4 files changed

+33
-15
lines changed

4 files changed

+33
-15
lines changed

scripts/convert_models_diffuser_to_diffusers.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,19 @@ def unet(hor):
2929
block_out_channels=block_out_channels,
3030
up_block_types=up_block_types,
3131
layers_per_block=1,
32+
use_timestep_embedding=True,
33+
out_block_type="OutConv1DBlock",
34+
norm_num_groups=8,
35+
downsample_each_block=False,
36+
in_channels=14,
37+
out_channels=14,
38+
extra_in_channels=0,
39+
time_embedding_type="positional",
40+
flip_sin_to_cos=False,
41+
freq_shift=1,
42+
sample_size=65536,
43+
mid_block_type="MidResTemporalBlock1D",
44+
act_fn="mish",
3245
)
3346
hf_value_function = UNet1DModel(**config)
3447
print(f"length of state dict: {len(state_dict.keys())}")
@@ -52,7 +65,16 @@ def value_function():
5265
mid_block_type="ValueFunctionMidBlock1D",
5366
block_out_channels=(32, 64, 128, 256),
5467
layers_per_block=1,
55-
always_downsample=True,
68+
downsample_each_block=True,
69+
sample_size=65536,
70+
out_channels=14,
71+
extra_in_channels=0,
72+
time_embedding_type="positional",
73+
use_timestep_embedding=True,
74+
flip_sin_to_cos=False,
75+
freq_shift=1,
76+
norm_num_groups=8,
77+
act_fn="mish",
5678
)
5779

5880
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu",
6969
self.act = None
7070
if act_fn == "silu":
7171
self.act = nn.SiLU()
72-
if act_fn == "mish":
72+
elif act_fn == "mish":
7373
self.act = nn.Mish()
7474

7575
if out_dim is not None:

src/diffusers/models/resnet.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -523,13 +523,9 @@ def forward(self, x):
523523
class ResidualTemporalBlock1D(nn.Module):
524524
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
525525
super().__init__()
526+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
527+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
526528

527-
self.blocks = nn.ModuleList(
528-
[
529-
Conv1dBlock(inp_channels, out_channels, kernel_size),
530-
Conv1dBlock(out_channels, out_channels, kernel_size),
531-
]
532-
)
533529
self.time_emb_act = nn.Mish()
534530
self.time_emb = nn.Linear(embed_dim, out_channels)
535531

@@ -548,8 +544,8 @@ def forward(self, x, t):
548544
"""
549545
t = self.time_emb_act(t)
550546
t = self.time_emb(t)
551-
out = self.blocks[0](x) + rearrange_dims(t)
552-
out = self.blocks[1](out)
547+
out = self.conv_in(x) + rearrange_dims(t)
548+
out = self.conv_out(out)
553549
return out + self.residual_conv(x)
554550

555551

src/diffusers/models/unet_1d.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def __init__(
7777
time_embedding_type: str = "fourier",
7878
flip_sin_to_cos: bool = True,
7979
use_timestep_embedding: bool = False,
80-
downscale_freq_shift: float = 0.0,
80+
freq_shift: float = 0.0,
8181
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
8282
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
8383
mid_block_type: Tuple[str] = "UNetMidBlock1D",
@@ -86,7 +86,7 @@ def __init__(
8686
act_fn: str = None,
8787
norm_num_groups: int = 8,
8888
layers_per_block: int = 1,
89-
always_downsample: bool = False,
89+
downsample_each_block: bool = False,
9090
):
9191
super().__init__()
9292
self.sample_size = sample_size
@@ -99,7 +99,7 @@ def __init__(
9999
timestep_input_dim = 2 * block_out_channels[0]
100100
elif time_embedding_type == "positional":
101101
self.time_proj = Timesteps(
102-
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift
102+
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
103103
)
104104
timestep_input_dim = block_out_channels[0]
105105

@@ -134,7 +134,7 @@ def __init__(
134134
in_channels=input_channel,
135135
out_channels=output_channel,
136136
temb_channels=block_out_channels[0],
137-
add_downsample=not is_final_block or always_downsample,
137+
add_downsample=not is_final_block or downsample_each_block,
138138
)
139139
self.down_blocks.append(down_block)
140140

@@ -146,7 +146,7 @@ def __init__(
146146
out_channels=block_out_channels[-1],
147147
embed_dim=block_out_channels[0],
148148
num_layers=layers_per_block,
149-
add_downsample=always_downsample,
149+
add_downsample=downsample_each_block,
150150
)
151151

152152
# up

0 commit comments

Comments
 (0)