|
15 | 15 | "metadata": {},
|
16 | 16 | "outputs": [],
|
17 | 17 | "source": [
|
18 |
| - "from nestedtensor import torch\n", |
| 18 | + "import torch\n", |
| 19 | + "import nestedtensor\n", |
19 | 20 | "from IPython.display import Markdown, display\n",
|
20 | 21 | "\n",
|
21 | 22 | "def print_eval(s):\n",
|
|
54 | 55 | "text": [
|
55 | 56 | "nested_tensor([\n",
|
56 | 57 | "\t[\n",
|
57 |
| - "\t\ttensor([[0.8264, 0.2200, 0.4197],\n", |
58 |
| - "\t\t [0.6789, 0.7460, 0.1694]]),\n", |
59 |
| - "\t\ttensor([[0.7467, 0.8433, 0.6429, 0.9890, 0.0170],\n", |
60 |
| - "\t\t [0.6297, 0.3899, 0.7025, 0.0812, 0.9585],\n", |
61 |
| - "\t\t [0.1113, 0.4260, 0.4245, 0.7971, 0.7910],\n", |
62 |
| - "\t\t [0.7077, 0.6765, 0.0228, 0.5461, 0.4095]])\n", |
| 58 | + "\t\ttensor([[0.1525, 0.9457, 0.8438],\n", |
| 59 | + "\t\t [0.6784, 0.9376, 0.5344]]),\n", |
| 60 | + "\t\ttensor([[0.5654, 0.6054, 0.2726, 0.8868, 0.3417],\n", |
| 61 | + "\t\t [0.1225, 0.4104, 0.9022, 0.6978, 0.2081],\n", |
| 62 | + "\t\t [0.5641, 0.2983, 0.7589, 0.5495, 0.1304],\n", |
| 63 | + "\t\t [0.1999, 0.3803, 0.0336, 0.4855, 0.9838]])\n", |
63 | 64 | "\t],\n",
|
64 | 65 | "\t[\n",
|
65 |
| - "\t\ttensor([[0.0660, 0.9756]])\n", |
| 66 | + "\t\ttensor([[0.8105, 0.6778]])\n", |
66 | 67 | "\t]\n",
|
67 | 68 | "])\n",
|
68 | 69 | "\n"
|
69 | 70 | ]
|
70 | 71 | }
|
71 | 72 | ],
|
72 | 73 | "source": [
|
73 |
| - "nt = torch.nested_tensor(\n", |
| 74 | + "nt = nestedtensor.nested_tensor(\n", |
74 | 75 | " [\n",
|
75 | 76 | " [\n",
|
76 | 77 | " torch.rand(2, 3),\n",
|
|
227 | 228 | " [3, 3],\n",
|
228 | 229 | " [4, 4],\n",
|
229 | 230 | " [5, 5]])\n",
|
230 |
| - "nt2 = torch.nested_tensor([[a],[b]])\n", |
| 231 | + "nt2 = nestedtensor.nested_tensor([[a],[b]])\n", |
231 | 232 | "print_eval(\"nt2.nested_dim()\")\n",
|
232 | 233 | "print_eval(\"nt2.tensor_dim()\")\n",
|
233 | 234 | "print_eval(\"nt2.dim()\")"
|
|
281 | 282 | "text": [
|
282 | 283 | "nested_tensor([\n",
|
283 | 284 | "\t[\n",
|
284 |
| - "\t\ttensor([[0.8264, 0.2200, 0.4197],\n", |
285 |
| - "\t\t [0.6789, 0.7460, 0.1694]]),\n", |
286 |
| - "\t\ttensor([[0.7467, 0.8433, 0.6429, 0.9890, 0.0170],\n", |
287 |
| - "\t\t [0.6297, 0.3899, 0.7025, 0.0812, 0.9585],\n", |
288 |
| - "\t\t [0.1113, 0.4260, 0.4245, 0.7971, 0.7910],\n", |
289 |
| - "\t\t [0.7077, 0.6765, 0.0228, 0.5461, 0.4095]])\n", |
| 285 | + "\t\ttensor([[0.1525, 0.9457, 0.8438],\n", |
| 286 | + "\t\t [0.6784, 0.9376, 0.5344]]),\n", |
| 287 | + "\t\ttensor([[0.5654, 0.6054, 0.2726, 0.8868, 0.3417],\n", |
| 288 | + "\t\t [0.1225, 0.4104, 0.9022, 0.6978, 0.2081],\n", |
| 289 | + "\t\t [0.5641, 0.2983, 0.7589, 0.5495, 0.1304],\n", |
| 290 | + "\t\t [0.1999, 0.3803, 0.0336, 0.4855, 0.9838]])\n", |
290 | 291 | "\t],\n",
|
291 | 292 | "\t[\n",
|
292 |
| - "\t\ttensor([[0.0660, 0.9756]])\n", |
| 293 | + "\t\ttensor([[0.8105, 0.6778]])\n", |
293 | 294 | "\t]\n",
|
294 | 295 | "])\n",
|
295 | 296 | "\n"
|
|
621 | 622 | {
|
622 | 623 | "data": {
|
623 | 624 | "text/markdown": [
|
624 |
| - "**<span style='color:darkred'>$ torch.nested_tensor_from_tensor_mask(tensor, mask)</span>**" |
| 625 | + "**<span style='color:darkred'>$ nestedtensor.nested_tensor_from_tensor_mask(tensor, mask)</span>**" |
625 | 626 | ],
|
626 | 627 | "text/plain": [
|
627 | 628 | "<IPython.core.display.Markdown object>"
|
|
646 | 647 | {
|
647 | 648 | "data": {
|
648 | 649 | "text/markdown": [
|
649 |
| - "**<span style='color:darkred'>$ torch.nested_tensor_from_padded_tensor(tensor, padding=0)</span>**" |
| 650 | + "**<span style='color:darkred'>$ nestedtensor.nested_tensor_from_padded_tensor(tensor, padding=0)</span>**" |
650 | 651 | ],
|
651 | 652 | "text/plain": [
|
652 | 653 | "<IPython.core.display.Markdown object>"
|
|
688 | 689 | " [ True, True, True, True]]])\n",
|
689 | 690 | "print_eval(\"tensor\")\n",
|
690 | 691 | "print_eval(\"mask\")\n",
|
691 |
| - "nt2 = torch.nested_tensor_from_tensor_mask(tensor, mask)\n", |
692 |
| - "print_eval(\"torch.nested_tensor_from_tensor_mask(tensor, mask)\")\n", |
693 |
| - "print_eval(\"torch.nested_tensor_from_padded_tensor(tensor, padding=0)\")" |
| 692 | + "nt2 = nestedtensor.nested_tensor_from_tensor_mask(tensor, mask)\n", |
| 693 | + "print_eval(\"nestedtensor.nested_tensor_from_tensor_mask(tensor, mask)\")\n", |
| 694 | + "print_eval(\"nestedtensor.nested_tensor_from_padded_tensor(tensor, padding=0)\")" |
694 | 695 | ]
|
695 | 696 | },
|
696 | 697 | {
|
|
795 | 796 | "output_type": "stream",
|
796 | 797 | "text": [
|
797 | 798 | "nested_tensor([\n",
|
798 |
| - "\ttensor([[0.8264, 0.2200, 0.4197],\n", |
799 |
| - "\t [0.6789, 0.7460, 0.1694]]),\n", |
800 |
| - "\ttensor([[0.7467, 0.8433, 0.6429, 0.9890, 0.0170],\n", |
801 |
| - "\t [0.6297, 0.3899, 0.7025, 0.0812, 0.9585],\n", |
802 |
| - "\t [0.1113, 0.4260, 0.4245, 0.7971, 0.7910],\n", |
803 |
| - "\t [0.7077, 0.6765, 0.0228, 0.5461, 0.4095]])\n", |
| 799 | + "\ttensor([[0.1525, 0.9457, 0.8438],\n", |
| 800 | + "\t [0.6784, 0.9376, 0.5344]]),\n", |
| 801 | + "\ttensor([[0.5654, 0.6054, 0.2726, 0.8868, 0.3417],\n", |
| 802 | + "\t [0.1225, 0.4104, 0.9022, 0.6978, 0.2081],\n", |
| 803 | + "\t [0.5641, 0.2983, 0.7589, 0.5495, 0.1304],\n", |
| 804 | + "\t [0.1999, 0.3803, 0.0336, 0.4855, 0.9838]])\n", |
804 | 805 | "])\n",
|
805 | 806 | "\n"
|
806 | 807 | ]
|
|
822 | 823 | "output_type": "stream",
|
823 | 824 | "text": [
|
824 | 825 | "nested_tensor([\n",
|
825 |
| - "\ttensor([[0.0660, 0.9756]])\n", |
| 826 | + "\ttensor([[0.8105, 0.6778]])\n", |
826 | 827 | "])\n",
|
827 | 828 | "\n"
|
828 | 829 | ]
|
|
857 | 858 | "text": [
|
858 | 859 | "nested_tensor([\n",
|
859 | 860 | "\t[\n",
|
860 |
| - "\t\ttensor([[0.6776, 0.9759, 0.9132],\n", |
861 |
| - "\t\t [0.7783, 0.7344, 0.9857]]),\n", |
862 |
| - "\t\ttensor([[0.7467, 0.8433, 0.6429, 0.9890, 0.0170],\n", |
863 |
| - "\t\t [0.6297, 0.3899, 0.7025, 0.0812, 0.9585],\n", |
864 |
| - "\t\t [0.1113, 0.4260, 0.4245, 0.7971, 0.7910],\n", |
865 |
| - "\t\t [0.7077, 0.6765, 0.0228, 0.5461, 0.4095]])\n", |
| 861 | + "\t\ttensor([[0.9884, 0.5852, 0.6646],\n", |
| 862 | + "\t\t [0.7786, 0.5917, 0.8606]]),\n", |
| 863 | + "\t\ttensor([[0.5654, 0.6054, 0.2726, 0.8868, 0.3417],\n", |
| 864 | + "\t\t [0.1225, 0.4104, 0.9022, 0.6978, 0.2081],\n", |
| 865 | + "\t\t [0.5641, 0.2983, 0.7589, 0.5495, 0.1304],\n", |
| 866 | + "\t\t [0.1999, 0.3803, 0.0336, 0.4855, 0.9838]])\n", |
866 | 867 | "\t],\n",
|
867 | 868 | "\t[\n",
|
868 |
| - "\t\ttensor([[0.0660, 0.9756]])\n", |
| 869 | + "\t\ttensor([[0.8105, 0.6778]])\n", |
869 | 870 | "\t]\n",
|
870 | 871 | "])\n",
|
871 | 872 | "\n"
|
|
0 commit comments