Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions pytm/json.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json

from .tm import TM
from .base import DataSet
from .boundary import Boundary
from .data import Data
from .dataflow import Dataflow
from .asset import Asset, Server, ExternalEntity, Lambda, LLM
from .datastore import Datastore
from .actor import Actor
from .process import Process, SetOfProcesses
from .enums import Action
from .enums import Action, Classification, Lifetime

_ELEMENT_CLASSES = {
"Asset": Asset,
Expand All @@ -24,36 +26,30 @@

def loads(s):
"""Load a TM object from a JSON string *s*."""
result = json.loads(s, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result
result = json.loads(s)
return _decode(result)


def load(fp):
"""Load a TM object from an open file containing JSON."""
result = json.load(fp, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result
result = json.load(fp)
return _decode(result)


def decode(data):
if "elements" not in data and "flows" not in data and "boundaries" not in data:
return data
def _decode(flat):
boundaries = _decode_boundaries(flat.pop("boundaries", []))
data = _decode_data(flat.pop("data", []))
elements = _decode_elements(flat.pop("elements", []), boundaries)
_decode_flows(flat.pop("flows", []), elements, data)

boundaries = decode_boundaries(data.pop("boundaries", []))
elements = decode_elements(data.pop("elements", []), boundaries)
decode_flows(data.pop("flows", []), elements)

if "name" not in data:
if "name" not in flat:
raise ValueError("name property missing for threat model")
if "onDuplicates" in data:
data["onDuplicates"] = Action(data["onDuplicates"])
return TM(data.pop("name"), **data)
if "onDuplicates" in flat:
flat["onDuplicates"] = Action(flat["onDuplicates"])
return TM(flat.pop("name"), **flat)


def decode_boundaries(flat):
def _decode_boundaries(flat):
boundaries = {}
refs = {}
for i, e in enumerate(flat):
Expand All @@ -74,7 +70,28 @@ def decode_boundaries(flat):
return boundaries


def decode_elements(flat, boundaries):
def _decode_data(flat):
data = {}
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
raise ValueError(f"name property missing in data {i}")

classification_name = e.pop("classification", None)
if classification_name:
e["classification"] = Classification[classification_name]

lifetime_name = e.pop("lifetime", None)
if lifetime_name:
e["lifetime"] = Lifetime[lifetime_name]

d = Data(name, **e)
data[name] = d

return data


def _decode_elements(flat, boundaries):
elements = {}
for i, e in enumerate(flat):
class_name = e.pop("__class__", "Asset")
Expand All @@ -96,7 +113,7 @@ def decode_elements(flat, boundaries):
return elements


def decode_flows(flat, elements):
def _decode_flows(flat, elements, data):
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
Expand All @@ -111,4 +128,13 @@ def decode_flows(flat, elements):
if e["sink"] not in elements:
raise ValueError(f"dataflow {name} references invalid sink {e['sink']}")
sink = elements[e.pop("sink")]

if "data" in e:
dataset = DataSet()
for data_name in e["data"]:
if data_name not in data:
raise ValueError(f"dataflow {name} references invalid data {data_name}")
dataset.add(data[data_name])
e["data"] = dataset

Dataflow(source, sink, name, **e)
10 changes: 10 additions & 0 deletions tests/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
"name": "Server/DB"
}
],
"data": [
{
"name": "Password",
"classification": "SECRET",
"lifetime": "LONG"
}
],
"elements": [
{
"__class__": "Actor",
Expand All @@ -32,6 +39,9 @@
"name": "Request",
"source": "User",
"sink": "Web Server",
"data": [
"Password"
],
"note": "bbb"
},
{
Expand Down
8 changes: 8 additions & 0 deletions tests/test_pytmfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ def test_json_loads(self):
]
assert [f.name for f in tm._flows] == ["Request", "Insert", "Select", "Response"]

assert [d.model_dump(include=["name", "classification", "lifetime"]) for d in tm._data] == [
{
"name": "Password",
"classification": Classification.SECRET,
"lifetime": Lifetime.LONG,
},
]

@pytest.mark.parametrize(
"class_name,expected_type",
[
Expand Down
Loading