Commit f8be997

update with generation testing
1 parent f2d0847 commit f8be997

5 files changed: +107 −12 lines

.DS_Store

Binary file (-6 KB) not shown.

.gitignore

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+.python-version
+bin/
+gpt2.*
+lib/
+pyvenv.cfg
+.DS_Store

README.md

Lines changed: 26 additions & 3 deletions

@@ -61,8 +61,9 @@ gpt2.mapping
 gpt2.xml
 ```

-## Checks
+## Tests

+#### Local Machine
 To check if everything works fine, run the script `run.py`. You should start seeing the outputs; the following is from a machine with the following configuration:
 ```
 MacBook Pro (13-inch, 2020, Four Thunderbolt 3 ports)
@@ -74,17 +75,39 @@ The performance results are as follows (`2x` boost):
 ```
 ----------------------------------------------------------------------
 Loading Pytorch model
-Pytorch inference in 0.59065s
+:: Pytorch inference in 0.59065s
 ----------------------------------------------------------------------
 Creating Inference Engine...
 Loading network
 Loading IR to the plugin...
 exec_net: <openvino.inference_engine.ie_api.ExecutableNetwork object at 0x12c531fb0>
+:: OpenVino inference in 0.26206s
 ----------------------------------------------------------------------
-OpenVino inference in 0.26206s
+```
+
+In order to test generation capabilities you can pass the `--g` flag and get the following results:
+```
+----------------------------------------------------------------------
+Loading Pytorch model
+Text shape: torch.Size([1, 127])
+:: Pytorch inference in 0.46476s
+----------------------------------------------------------------------
+Testing generation
+:: Pytorch generation took (40 steps): 17.663s
+----------------------------------------------------------------------
+Creating Inference Engine...
+Loading network
+Loading IR to the plugin...
+exec_net: <openvino.inference_engine.ie_api.ExecutableNetwork object at 0x130aaffb0>
+:: OpenVino inference in 0.23262s
+----------------------------------------------------------------------
+Testing generation
+:: OpenVino generation took (40 steps): 6.220s
 ----------------------------------------------------------------------
 ```

+#### Cloud Server
+
 When running on AWS `c5.12xlarge` and batching the data into batches of `128` samples, we see a larger performance increase.
 ```
 ----------------------------------------------------------------------
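For reference, the benchmarks above come straight from `run.py`: assuming the converted IR sits next to the script (the default `--model` path is `./gpt2.xml`, per the diff below), plain inference is `python run.py` and the generation comparison is `python run.py --g`.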

requirements.txt

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+certifi==2020.12.5
+chardet==4.0.0
+click==7.1.2
+filelock==3.0.12
+idna==2.10
+importlib-metadata==3.4.0
+joblib==1.0.1
+numpy==1.20.1
+packaging==20.9
+pyparsing==2.4.7
+regex==2020.11.13
+requests==2.25.1
+sacremoses==0.0.43
+six==1.15.0
+tokenizers==0.10.1
+torch==1.7.1
+tqdm==4.56.2
+transformers==4.3.2
+typing-extensions==3.7.4.3
+urllib3==1.26.3
+zipp==3.4.0

run.py

Lines changed: 54 additions & 9 deletions

@@ -13,32 +13,63 @@
 """
 import os
 import time
+import numpy as np
 from argparse import ArgumentParser

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

 from openvino.inference_engine import IECore

+def generate_greedy_pytorch(tokens, model, n):
+    complete_seq = tokens.permute((1, 0)).tolist()
+    for _ in range(n):
+        out = model(tokens)
+        next_tokens = torch.argmax(out.logits[:, -1], dim=-1).unsqueeze(1)
+        tokens = torch.cat([tokens, next_tokens], dim=-1)
+        tokens = tokens[:, 1:]
+        complete_seq.extend(next_tokens.tolist())
+    return np.array(complete_seq).T.tolist()
+
+
+def generate_greedy_openvino(tokens, exec_net, n, logits_dict_key="2859"):
+    complete_seq = tokens.T.tolist()
+    for _ in range(n):
+        out = exec_net.infer(inputs={"0": tokens})[logits_dict_key]
+        next_tokens = np.argmax(out[:, -1], axis=-1).reshape(-1, 1)
+        tokens = np.hstack((tokens, next_tokens))
+        tokens = tokens[:, 1:]
+        complete_seq.extend(next_tokens.tolist())
+    return np.array(complete_seq).T.tolist()
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--model", help="Path to an .xml file with a trained model.", default="gpt2.xml", type=str)
+    parser.add_argument("--model", help="Path to an .xml file with a trained model.", default="./gpt2.xml", type=str)
+    parser.add_argument("--g", help="if set, the model will also test generation", action="store_true", default=False)
+    args = parser.parse_args()

     print("-"*70)
     print("Loading Pytorch model")
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     model = AutoModelForCausalLM.from_pretrained("gpt2")
     with open("text.en", "r") as f:
         text = f.read()
-    input_encoder = tokenizer([text for _ in range(100)], return_tensors="pt")
-    print(":: -->", input_encoder["input_ids"].size())
+    input_encoder = tokenizer([text + tokenizer.eos_token for _ in range(1)], return_tensors="pt")
+    print("Text shape:", input_encoder["input_ids"].size())

     st = time.time()
     model(input_encoder["input_ids"])
-    print(f"Pytorch inference in {time.time() - st:.5f}s")
-    del model, tokenizer
+    print(f":: Pytorch inference in {time.time() - st:.5f}s")
+    if args.g:
+        print("-"*70)
+        print("Testing generation")
+        st = time.time()
+        out = generate_greedy_pytorch(input_encoder["input_ids"], model, n=40)
+        out = tokenizer.decode(out[0])
+        print(f":: Pytorch generation took (40 steps): {time.time() - st:.3f}s")
+    del model

-    args = parser.parse_args()
     print("-"*70)
     model_xml = args.model
     model_bin = os.path.splitext(model_xml)[0] + ".bin"
@@ -54,7 +85,6 @@
     print("Loading IR to the plugin...")
     exec_net = ie.load_network(network=net, device_name="CPU", num_requests=2)
     print(f"exec_net: {exec_net}")
-    print("-"*70)

     # this is a bit tricky. So the input to the model is the input from the ONNX graph.
     # IECore makes a networkX graph of the "computation graph" and when we run .infer
@@ -63,6 +93,21 @@
     # suspect. Happy Hunting!
     inputs = input_encoder["input_ids"].tolist()
     st = time.time()
-    out = exec_net.infer(inputs={"0": [1 for _ in range(127)]})
-    print(f"OpenVino inference in {time.time() - st:.5f}s")
+    out = exec_net.infer(inputs={"0": inputs})
+
+    # now this out is a dictionary and has a lot of outputs, so you will need to manually
+    # determine which is the output that you want by checking the correct shape
+    # for k in list(out.keys()):
+    #     print(k, "-->", out[k].shape)
+
+    print(f":: OpenVino inference in {time.time() - st:.5f}s")
+
+    if args.g:
+        print("-"*70)
+        print("Testing generation")
+        st = time.time()
+        out = generate_greedy_openvino(input_encoder["input_ids"].numpy(), exec_net, n=40)
+        out = tokenizer.decode(out[0])
+        print(f":: OpenVino generation took (40 steps): {time.time() - st:.3f}s")
+
     print("-"*70)
