|
15 | 15 | */
|
16 | 16 | package org.tensorflow;
|
17 | 17 |
|
| 18 | +import java.util.Map; |
18 | 19 | import java.util.Set;
|
19 | 20 | import org.tensorflow.ndarray.Shape;
|
20 | 21 | import org.tensorflow.proto.framework.DataType;
|
@@ -145,15 +146,50 @@ public Set<String> outputNames() {
|
145 | 146 | return signatureDef.getOutputsMap().keySet();
|
146 | 147 | }
|
147 | 148 |
|
| 149 | + @Override |
| 150 | + public String toString() { |
| 151 | + StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n"); |
| 152 | + if (!methodName().isEmpty()) { |
| 153 | + strBuilder.append("\tMethod: \"").append(methodName()).append("\"\n"); |
| 154 | + } |
| 155 | + if (signatureDef.getInputsCount() > 0) { |
| 156 | + strBuilder.append("\tInputs:\n"); |
| 157 | + printTensorInfo(signatureDef.getInputsMap(), strBuilder); |
| 158 | + } |
| 159 | + if (signatureDef.getOutputsCount() > 0) { |
| 160 | + strBuilder.append("\tOutputs:\n"); |
| 161 | + printTensorInfo(signatureDef.getOutputsMap(), strBuilder); |
| 162 | + } |
| 163 | + return strBuilder.toString(); |
| 164 | + } |
| 165 | + |
| 166 | + Signature(String key, SignatureDef signatureDef) { |
| 167 | + this.key = key; |
| 168 | + this.signatureDef = signatureDef; |
| 169 | + } |
| 170 | + |
148 | 171 | SignatureDef asSignatureDef() {
|
149 | 172 | return signatureDef;
|
150 | 173 | }
|
151 | 174 |
|
152 | 175 | private final String key;
|
153 | 176 | private final SignatureDef signatureDef;
|
154 | 177 |
|
155 |
| - Signature(String key, SignatureDef signatureDef) { |
156 |
| - this.key = key; |
157 |
| - this.signatureDef = signatureDef; |
| 178 | + private static void printTensorInfo(Map<String, TensorInfo> tensorMap, StringBuilder strBuilder) { |
| 179 | + tensorMap.forEach((key, tensorInfo) -> { |
| 180 | + strBuilder.append("\t\t\"") |
| 181 | + .append(key) |
| 182 | + .append("\":") |
| 183 | + .append(" dtype=") |
| 184 | + .append(tensorInfo.getDtype().name()) |
| 185 | + .append(", shape=("); |
| 186 | + for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { |
| 187 | + strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); |
| 188 | + if (i < tensorInfo.getTensorShape().getDimCount() - 1) { |
| 189 | + strBuilder.append(", "); |
| 190 | + } |
| 191 | + } |
| 192 | + strBuilder.append(")\n"); |
| 193 | + }); |
158 | 194 | }
|
159 | 195 | }
|
0 commit comments