diff --git a/pytm/json.py b/pytm/json.py index e9a1deb..d202b71 100644 --- a/pytm/json.py +++ b/pytm/json.py @@ -1,13 +1,15 @@ import json from .tm import TM +from .base import DataSet from .boundary import Boundary +from .data import Data from .dataflow import Dataflow from .asset import Asset, Server, ExternalEntity, Lambda, LLM from .datastore import Datastore from .actor import Actor from .process import Process, SetOfProcesses -from .enums import Action +from .enums import Action, Classification, Lifetime _ELEMENT_CLASSES = { "Asset": Asset, @@ -24,36 +26,30 @@ def loads(s): """Load a TM object from a JSON string *s*.""" - result = json.loads(s, object_hook=decode) - if not isinstance(result, TM): - raise ValueError("Failed to decode JSON input as TM") - return result + result = json.loads(s) + return _decode(result) def load(fp): """Load a TM object from an open file containing JSON.""" - result = json.load(fp, object_hook=decode) - if not isinstance(result, TM): - raise ValueError("Failed to decode JSON input as TM") - return result + result = json.load(fp) + return _decode(result) -def decode(data): - if "elements" not in data and "flows" not in data and "boundaries" not in data: - return data +def _decode(flat): + boundaries = _decode_boundaries(flat.pop("boundaries", [])) + data = _decode_data(flat.pop("data", [])) + elements = _decode_elements(flat.pop("elements", []), boundaries) + _decode_flows(flat.pop("flows", []), elements, data) - boundaries = decode_boundaries(data.pop("boundaries", [])) - elements = decode_elements(data.pop("elements", []), boundaries) - decode_flows(data.pop("flows", []), elements) - - if "name" not in data: + if "name" not in flat: raise ValueError("name property missing for threat model") - if "onDuplicates" in data: - data["onDuplicates"] = Action(data["onDuplicates"]) - return TM(data.pop("name"), **data) + if "onDuplicates" in flat: + flat["onDuplicates"] = Action(flat["onDuplicates"]) + return TM(flat.pop("name"), **flat) -def decode_boundaries(flat): +def _decode_boundaries(flat): boundaries = {} refs = {} for i, e in enumerate(flat): @@ -74,7 +70,28 @@ def decode_boundaries(flat): return boundaries -def decode_elements(flat, boundaries): +def _decode_data(flat): + data = {} + for i, e in enumerate(flat): + name = e.pop("name", None) + if name is None: + raise ValueError(f"name property missing in data {i}") + + classification_name = e.pop("classification", None) + if classification_name: + e["classification"] = Classification[classification_name] + + lifetime_name = e.pop("lifetime", None) + if lifetime_name: + e["lifetime"] = Lifetime[lifetime_name] + + d = Data(name, **e) + data[name] = d + + return data + + +def _decode_elements(flat, boundaries): elements = {} for i, e in enumerate(flat): class_name = e.pop("__class__", "Asset") @@ -96,7 +113,7 @@ def decode_elements(flat, boundaries): return elements -def decode_flows(flat, elements): +def _decode_flows(flat, elements, data): for i, e in enumerate(flat): name = e.pop("name", None) if name is None: @@ -111,4 +128,13 @@ def decode_flows(flat, elements): if e["sink"] not in elements: raise ValueError(f"dataflow {name} references invalid sink {e['sink']}") sink = elements[e.pop("sink")] + + if "data" in e: + dataset = DataSet() + for data_name in e["data"]: + if data_name not in data: + raise ValueError(f"dataflow {name} references invalid data {data_name}") + dataset.add(data[data_name]) + e["data"] = dataset + Dataflow(source, sink, name, **e) diff --git a/tests/input.json b/tests/input.json index 0cdcff8..798a284 100644 --- a/tests/input.json +++ b/tests/input.json @@ -11,6 +11,13 @@ "name": "Server/DB" } ], + "data": [ + { + "name": "Password", + "classification": "SECRET", + "lifetime": "LONG" + } + ], "elements": [ { "__class__": "Actor", @@ -32,6 +39,9 @@ "name": "Request", "source": "User", "sink": "Web Server", + "data": [ + "Password" + ], "note": "bbb" }, { diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index c0696be..00dac53 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -426,6 +426,14 @@ def test_json_loads(self): ] assert [f.name for f in tm._flows] == ["Request", "Insert", "Select", "Response"] + assert [d.model_dump(include=["name", "classification", "lifetime"]) for d in tm._data] == [ + { + "name": "Password", + "classification": Classification.SECRET, + "lifetime": Lifetime.LONG, + }, + ] + @pytest.mark.parametrize( "class_name,expected_type", [