diff --git a/integration-test/test/test_plutus.py b/integration-test/test/test_plutus.py index c57119f0..acb554c8 100644 --- a/integration-test/test/test_plutus.py +++ b/integration-test/test/test_plutus.py @@ -26,7 +26,7 @@ def test_plutus_v1(self): builder = TransactionBuilder(self.chain_context) builder.add_input_address(giver_address) - datum = PlutusData() # A Unit type "()" in Haskell + datum = Unit() # A Unit type "()" in Haskell builder.add_output( TransactionOutput(script_address, 50000000, datum_hash=datum_hash(datum)) ) diff --git a/pycardano/plutus.py b/pycardano/plutus.py index 2fc7f0a4..c8562874 100644 --- a/pycardano/plutus.py +++ b/pycardano/plutus.py @@ -6,7 +6,8 @@ import json from dataclasses import dataclass, field, fields from enum import Enum -from typing import Any, ClassVar, Optional, Type, Union +from hashlib import sha256 +from typing import Any, Optional, Type, Union import cbor2 from cbor2 import CBORTag @@ -44,9 +45,16 @@ "datum_hash", "plutus_script_hash", "script_hash", + "Unit", ] +# taken from https://stackoverflow.com/a/13624858 +class classproperty(property): + def __get__(self, owner_self, owner_cls): + return self.fget(owner_cls) + + class CostModels(DictCBORSerializable): KEY_TYPE = int VALUE_TYPE = dict @@ -460,9 +468,25 @@ class will reduce the complexity of serialization and deserialization tremendous >>> assert test == Test.from_cbor("d87a9f187b43333231ff") """ - CONSTR_ID: ClassVar[int] = 0 - """Constructor ID of this plutus data. - It is primarily used by Plutus core to reconstruct a data structure from serialized CBOR bytes.""" + @classproperty + def CONSTR_ID(cls): + """ + Constructor ID of this plutus data. + It is primarily used by Plutus core to reconstruct a data structure from serialized CBOR bytes. + The default implementation is an almost unique, deterministic constructor ID in the range 1 - 2^32 based + on class attributes, types and class name. + """ + k = f"_CONSTR_ID_{cls.__name__}" + if not hasattr(cls, k): + det_string = ( + cls.__name__ + + "*" + + "*".join([f"{f.name}~{f.type}" for f in fields(cls)]) + ) + det_hash = sha256(det_string.encode("utf8")).hexdigest() + setattr(cls, k, int(det_hash, 16) % 2**32) + + return getattr(cls, k) def __post_init__(self): valid_types = (PlutusData, dict, IndefiniteList, int, bytes) @@ -820,3 +844,10 @@ def script_hash(script: ScriptType) -> ScriptHash: ) else: raise TypeError(f"Unexpected script type: {type(script)}") + + +@dataclass +class Unit(PlutusData): + """The default "Unit type" with a 0 constructor ID""" + + CONSTR_ID = 0 diff --git a/test/pycardano/test_plutus.py b/test/pycardano/test_plutus.py index 35cc9ff0..f2aba0f9 100644 --- a/test/pycardano/test_plutus.py +++ b/test/pycardano/test_plutus.py @@ -1,5 +1,7 @@ import copy -import unittest +import subprocess +import sys +import tempfile from dataclasses import dataclass from test.pycardano.util import check_two_way_cbor from typing import Dict, List, Union @@ -7,7 +9,7 @@ import pytest from cbor2 import CBORTag -from pycardano.exception import DeserializeException, SerializeException +from pycardano.exception import DeserializeException from pycardano.plutus import ( COST_MODELS, ExecutionUnits, @@ -51,6 +53,7 @@ class DictTest(PlutusData): @dataclass class ListTest(PlutusData): + CONSTR_ID = 0 a: List[LargestTest] @@ -204,7 +207,7 @@ def test_plutus_data_from_json_wrong_data_structure_type(): def test_plutus_data_hash(): assert ( bytes.fromhex( - "923918e403bf43c34b4ef6b48eb2ee04babed17320d8d1b9ff9ad086e86f44ec" + "19d31e4f3aa9b03ad93b64c8dd2cc822d247c21e2c22762b7b08e6cadfeddb47" ) == PlutusData().hash().payload ) @@ -316,3 +319,80 @@ def test_clone_plutus_data(): my_vesting.deadline = 1643235300001 assert cloned_vesting != my_vesting + + +def test_unique_constr_ids(): + @dataclass + class A(PlutusData): + pass + + @dataclass + class B(PlutusData): + pass + + assert ( + A.CONSTR_ID != B.CONSTR_ID + ), "Different classes (different names) have same default constructor ID" + B_tmp = B + + @dataclass + class B(PlutusData): + a: int + b: bytes + + assert ( + B_tmp.CONSTR_ID != B.CONSTR_ID + ), "Different classes (different fields) have same default constructor ID" + + B_tmp = B + + @dataclass + class B(PlutusData): + a: bytes + b: bytes + + assert ( + B_tmp.CONSTR_ID != B.CONSTR_ID + ), "Different classes (different field types) have same default constructor ID" + + +def test_deterministic_constr_ids_local(): + @dataclass + class A(PlutusData): + a: int + b: bytes + + A_tmp = A + + @dataclass + class A(PlutusData): + a: int + b: bytes + + assert ( + A_tmp.CONSTR_ID == A.CONSTR_ID + ), "Same class has different default constructor ID" + + +def test_deterministic_constr_ids_global(): + code = """ +from dataclasses import dataclass +from pycardano import PlutusData + +@dataclass +class A(PlutusData): + a: int + b: bytes + +print(A.CONSTR_ID) +""" + tmpfile = tempfile.TemporaryFile() + tmpfile.write(code.encode("utf8")) + tmpfile.seek(0) + res = subprocess.run([sys.executable], stdin=tmpfile, capture_output=True).stdout + tmpfile.seek(0) + res2 = subprocess.run([sys.executable], stdin=tmpfile, capture_output=True).stdout + + assert ( + res == res2 + ), "Same class has different default constructor id in two consecutive runs" diff --git a/test/pycardano/test_util.py b/test/pycardano/test_util.py index 50e033b7..118c8e13 100644 --- a/test/pycardano/test_util.py +++ b/test/pycardano/test_util.py @@ -1,7 +1,7 @@ from test.pycardano.util import chain_context from pycardano.hash import SCRIPT_HASH_SIZE, ScriptDataHash -from pycardano.plutus import ExecutionUnits, PlutusData, Redeemer, RedeemerTag +from pycardano.plutus import ExecutionUnits, PlutusData, Redeemer, RedeemerTag, Unit from pycardano.transaction import Value from pycardano.utils import min_lovelace_pre_alonzo, script_data_hash @@ -145,7 +145,7 @@ def test_min_lovelace_multi_asset_9(self, chain_context): def test_script_data_hash(): - unit = PlutusData() + unit = Unit() redeemers = [Redeemer(unit, ExecutionUnits(1000000, 1000000))] redeemers[0].tag = RedeemerTag.SPEND assert ScriptDataHash.from_primitive( @@ -154,14 +154,14 @@ def test_script_data_hash(): def test_script_data_hash_datum_only(): - unit = PlutusData() + unit = Unit() assert ScriptDataHash.from_primitive( "2f50ea2546f8ce020ca45bfcf2abeb02ff18af2283466f888ae489184b3d2d39" ) == script_data_hash(redeemers=[], datums=[unit]) def test_script_data_hash_redeemer_only(): - unit = PlutusData() + unit = Unit() redeemers = [] assert ScriptDataHash.from_primitive( "a88fe2947b8d45d1f8b798e52174202579ecf847b8f17038c7398103df2d27b0"