Skip to content

Commit 8c2e42d

Browse files
Refactored project code
Fixed the entropy model and improved training speed by using the TFRecord format for the dataset; a script is included to construct the records from any image dataset.
1 parent 57e93ab commit 8c2e42d

File tree

7 files changed

+564
-0
lines changed

7 files changed

+564
-0
lines changed

LPIPS.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
import tensorflow as tf
3+
import urllib.request
4+
5+
_LPIPS_URL = "http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb"
6+
7+
8+
def ensure_lpips_weights_exist(weight_path_out):
    """Make sure the LPIPS frozen-graph weights exist at *weight_path_out*,
    downloading them on demand."""
    already_there = os.path.isfile(weight_path_out)
    if not already_there:
        print("Downloading LPIPS weights:", _LPIPS_URL, "->", weight_path_out)
        urllib.request.urlretrieve(_LPIPS_URL, weight_path_out)
        # urlretrieve does not raise on every failure mode; verify the file
        # actually landed on disk.
        if not os.path.isfile(weight_path_out):
            raise ValueError(f"Failed to download LPIPS weights from {_LPIPS_URL} "
                             f"to {weight_path_out}. Please manually download!")
17+
18+
19+
class LPIPSLoss(object):
    """Calculate LPIPS loss.

    Loads a pre-trained frozen TF1 graph (net-lin_alex_v0.1, downloaded on
    demand) and wraps it as a TF2 callable that returns a scalar loss.
    """

    def __init__(self, weight_path):
        # Download the frozen-graph weights first if they are missing.
        ensure_lpips_weights_exist(weight_path)

        def wrap_frozen_graph(graph_def, inputs, outputs):
            # Import a TF1 GraphDef into a TF2 ConcreteFunction and prune it
            # down to the requested input/output tensor names.
            def _imports_graph_def():
                tf.graph_util.import_graph_def(graph_def, name="")

            wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
            import_graph = wrapped_import.graph
            return wrapped_import.prune(
                tf.nest.map_structure(import_graph.as_graph_element, inputs),
                tf.nest.map_structure(import_graph.as_graph_element, outputs))

        # Pack LPIPS network into a tf function
        graph_def = tf.compat.v1.GraphDef()
        with open(weight_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # Tensor names "0:0"/"1:0"/"Reshape_10:0" are fixed by the exported
        # frozen graph; do not change them without re-exporting the model.
        self._lpips_func = tf.function(
            wrap_frozen_graph(
                graph_def, inputs=("0:0", "1:0"), outputs="Reshape_10:0"))

    def __call__(self, fake_image, real_image):
        """Assuming inputs are in [0, 1].

        Both arguments are expected in NHWC layout (they are transposed to
        NCHW below before entering the frozen graph).
        """

        # Move inputs to [-1, 1] and NCHW format.
        def _transpose_to_nchw(x):
            return tf.transpose(x, (0, 3, 1, 2))

        fake_image = _transpose_to_nchw(fake_image * 2 - 1.0)
        real_image = _transpose_to_nchw(real_image * 2 - 1.0)
        loss = self._lpips_func(fake_image, real_image)
        return tf.reduce_mean(loss)  # Loss is N111, take mean to get scalar.

compress.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
import tensorflow as tf
3+
import tensorflow_compression as tfc
4+
import argparse
5+
from glob import glob
6+
7+
##process the image into suitable dimension
8+
def load_img(path):
    """Read the file at *path* and decode it as a 3-channel image tensor."""
    raw = tf.io.read_file(path)
    return tf.image.decode_image(raw, channels=3)
12+
13+
14+
def load_model(args):
    """Load the saved compression model from args.model_path (no compile)."""
    return tf.keras.models.load_model(args.model_path, compile=False)
17+
18+
19+
def compress(args):
    """Compress one image (or every image in a directory) to bitstreams.

    Writes one packed bitstream per input image under outputs/binary/ and
    prints the resulting bitrate in bits per pixel.

    Args:
      args: parsed CLI args with `model_path` and `image_path` attributes.
    """
    model = load_model(args)

    os.makedirs('outputs/binary', exist_ok=True)

    # Accept either a single image file or a directory of images.
    if os.path.isdir(args.image_path):
        paths = glob(os.path.join(args.image_path, '*'))
    else:
        paths = [args.image_path]

    for path in paths:
        # splitext (unlike split('.')[0]) keeps names containing extra dots.
        name = os.path.splitext(os.path.basename(path))[0]
        bitpath = "outputs/binary/{}.pth".format(name)

        image = load_img(path)
        # NOTE(review): assumes model.compress accepts a single HWC image —
        # confirm against the model's compress signature.
        compressed = model.compress(image)
        packed = tfc.PackedTensors()
        packed.pack(compressed)
        with open(bitpath, "wb") as f:
            f.write(packed.string)

        # Cast the pixel count (an eager tensor) to a plain int: the float
        # format spec below raises on an EagerTensor operand.
        num_pixels = int(tf.reduce_prod(tf.shape(image)[:-1]))
        bpp = len(packed.string) * 8 / num_pixels

        print('=============================================================')
        print(os.path.basename(path))

        print('bitrate : {0:.4}bpp'.format(bpp))
        print('=============================================================\n')
50+
51+
52+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positionals optional so the defaults can apply;
    # argparse silently ignores `default=` on a required positional.
    parser.add_argument('model_path', type=str, nargs='?', default='final_model')
    parser.add_argument('image_path', type=str, nargs='?', default='kodak/kodim20.png')

    args = parser.parse_args()

    compress(args)

decompress.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import tensorflow as tf
3+
import tensorflow_compression as tfc
4+
import argparse
5+
from glob import glob
6+
7+
def load_model(args):
    """Restore the saved Keras model named by args.model_path, skipping compilation."""
    restored = tf.keras.models.load_model(args.model_path, compile=False)
    return restored
10+
11+
def decompress(model, args, dtypes=None):
    """Decode packed bitstreams back to PNG reconstructions.

    Args:
      model: loaded compression model exposing a `decompress` tf.function.
      args: parsed CLI args with a `binary_path` (file or directory).
      dtypes: optional dtype list for unpacking the bitstream. When None it
        is derived from the model's decompress signature — this removes the
        previous hidden dependency on a module-global `dtypes`, which made
        the function raise NameError when imported and called directly.
    """
    if dtypes is None:
        dtypes = [t.dtype for t in model.decompress.input_signature]

    os.makedirs("outputs/reconstruction/", exist_ok=True)

    if os.path.isdir(args.binary_path):
        pathes = glob(os.path.join(args.binary_path, '*'))
    else:
        pathes = [args.binary_path]

    for path in pathes:

        print('========================================================================')
        print('image', os.path.basename(path))

        with open(path, "rb") as f:
            packed = tfc.PackedTensors(f.read())
        tensors = packed.unpack(dtypes)
        x_hat = model.decompress(*tensors)

        # splitext keeps names containing extra dots intact.
        name = os.path.splitext(os.path.basename(path))[0]
        fakepath = "./outputs/reconstruction/{}.png".format(name)
        # NOTE(review): tf.image.encode_png expects uint8/uint16 input —
        # assumes the model already returns uint8; confirm.
        string = tf.image.encode_png(x_hat)
        tf.io.write_file(fakepath, string)

        print('========================================================================\n')
36+
37+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path')
    parser.add_argument('binary_path')
    args = parser.parse_args()

    model = load_model(args)
    # The dtype list needed to unpack each bitstream, read off the model's
    # decompress signature.
    dtypes = [t.dtype for t in model.decompress.input_signature]

    decompress(model, args)

lpips_weights/net-lin_alex_v0.1.pb

9.49 MB
Binary file not shown.

0 commit comments

Comments
 (0)