Skip to content

Commit abc790d

Browse files
authored
Merge pull request #108 from ourstudio-se/107-050-quick-fixes
107 050 quick fixes
2 parents 479f1f1 + ac700b3 commit abc790d

4 files changed

Lines changed: 208 additions & 50 deletions

File tree

puan/logic/plog/__init__.py

Lines changed: 93 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import puan
1414
import puan.ndarray as pnd
1515
import puan_rspy as pr
16-
import dictdiffer
1716
import more_itertools
1817
from dataclasses import dataclass
1918
from collections import Counter
@@ -107,10 +106,10 @@ def __init__(self, value: int, propositions: typing.List[typing.Union[str, puan.
107106
raise Exception(f"`sign` of AtLeast proposition must be either -1 or 1, got: {sign}")
108107

109108

110-
propositions_list = list(propositions)
111-
if propositions is None or len(propositions_list) == 0:
109+
if propositions is None:
112110
raise Exception("Sub propositions cannot be `None`")
113111

112+
propositions_list = list(propositions)
114113
self.propositions = sorted(
115114
itertools.chain(
116115
filter(
@@ -1074,8 +1073,9 @@ def from_short(short: typing.Tuple[str, puan.Sign, typing.List[str], int, typing
10741073
def solve(
10751074
self,
10761075
objectives: typing.List[typing.Dict[typing.Union[str, puan.variable], int]],
1077-
solver: typing.Callable[[pnd.ge_polyhedron, typing.Dict[str, int]], typing.Iterable[typing.Tuple[np.ndarray, int, int]]] = None,
1076+
solver: typing.Callable[[pnd.ge_polyhedron, typing.Iterable[np.ndarray]], typing.Iterable[typing.Tuple[typing.Optional[np.ndarray], typing.Optional[int], int]]] = None,
10781077
try_reduce_before: bool = False,
1078+
include_virtual_variables: bool = False,
10791079
) -> itertools.starmap:
10801080

10811081
"""
@@ -1109,6 +1109,9 @@ def solve(
11091109
If true, then methods will be applied to try and reduce size of this model before
11101110
running solve function.
11111111
1112+
include_virtual_variables : bool = False
1113+
If true, the virtual/artificial variables that has automatically been generated creating the model will be included in solutions
1114+
11121115
Examples
11131116
--------
11141117
>>> dummy_solver = lambda x,y: list(map(lambda v: (v, 0, 5), y))
@@ -1132,35 +1135,57 @@ def solve(
11321135
if solver is None:
11331136
pyrs_theory, variable_id_map = self._to_pyrs_theory()
11341137
id_map = dict(variable_id_map.values())
1135-
return list(
1136-
itertools.starmap(
1137-
lambda solution, objective_value, status_code: (
1138-
dict(
1139-
itertools.starmap(
1140-
lambda k,v: (id_map[k].id, v),
1141-
solution.items()
1142-
)
1143-
),
1144-
objective_value, status_code
1145-
),
1146-
pyrs_theory.solve(
1147-
list(
1148-
map(
1149-
lambda objective: dict(
1150-
zip(
1151-
map(
1152-
lambda k: variable_id_map[k][0],
1153-
objective,
1138+
return itertools.starmap(
1139+
lambda solution, objective_value, status_code: (
1140+
dict(
1141+
itertools.starmap(
1142+
lambda k,v: (id_map[k].id, v),
1143+
filter(
1144+
maz.ifttt(
1145+
# If is puan.variable
1146+
lambda x: issubclass(id_map[x[0]].__class__, puan.variable),
1147+
1148+
# then include it
1149+
lambda _: True,
1150+
1151+
# else if variable id is generated
1152+
maz.ifttt(
1153+
maz.compose(
1154+
operator.attrgetter("generated_id"),
1155+
id_map.get,
1156+
operator.itemgetter(0),
11541157
),
1155-
objective.values(),
1158+
1159+
# then also check that we should include those variables
1160+
lambda _: include_virtual_variables,
1161+
1162+
# else include it
1163+
lambda _: True
11561164
)
11571165
),
1158-
objectives,
1166+
solution.items()
1167+
)
1168+
)
1169+
),
1170+
objective_value, status_code
1171+
),
1172+
pyrs_theory.solve(
1173+
list(
1174+
map(
1175+
lambda objective: dict(
1176+
zip(
1177+
map(
1178+
lambda k: variable_id_map[k][0],
1179+
objective,
1180+
),
1181+
objective.values(),
1182+
)
11591183
),
1184+
objectives,
11601185
),
1161-
False,
11621186
),
1163-
)
1187+
False,
1188+
),
11641189
)
11651190
else:
11661191
polyhedron = self.to_ge_polyhedron(
@@ -1176,13 +1201,36 @@ def solve(
11761201
return itertools.starmap(
11771202
lambda solution, objective_value, status_code: (
11781203
dict(
1179-
zip(
1180-
map(
1181-
operator.attrgetter("id"),
1182-
polyhedron.A.variables
1183-
),
1184-
solution
1185-
)
1204+
map(
1205+
lambda x: (x[0].id, x[1]),
1206+
filter(
1207+
maz.ifttt(
1208+
# if variable is a puan.variable
1209+
lambda x: issubclass(x[0].__class__, puan.variable),
1210+
1211+
# then keep it
1212+
lambda _: True,
1213+
1214+
# else if variable id is generated
1215+
maz.ifttt(
1216+
maz.compose(
1217+
operator.attrgetter("generated_id"),
1218+
operator.itemgetter(0),
1219+
),
1220+
1221+
# then check also that we should include those variables
1222+
lambda _: include_virtual_variables,
1223+
1224+
# else include it
1225+
lambda _: True
1226+
)
1227+
),
1228+
zip(
1229+
polyhedron.A.variables,
1230+
solution
1231+
)
1232+
)
1233+
),
11861234
) if solution is not None else {},
11871235
objective_value, status_code
11881236
),
@@ -1784,8 +1832,8 @@ def __init__(self, *propositions, variable: typing.Union[str, puan.variable] = N
17841832
variable=variable,
17851833
)
17861834

1787-
@staticmethod
1788-
def from_json(data: dict, class_map) -> "Xor":
1835+
@classmethod
1836+
def from_json(cls, data: dict, class_map) -> "Xor":
17891837
"""
17901838
Convert from JSON data to a proposition.
17911839
@@ -1794,13 +1842,13 @@ def from_json(data: dict, class_map) -> "Xor":
17941842
out : :class:`Xor`
17951843
"""
17961844
propositions = data.get('propositions', [])
1797-
return Xor(
1845+
return cls(
17981846
*map(functools.partial(from_json, class_map=class_map), propositions),
17991847
variable=data.get('id', None)
18001848
)
18011849

1802-
@staticmethod
1803-
def from_list(propositions: typing.List[typing.Union["AtLeast", puan.variable]], variable: typing.Union[str, puan.variable] = None) -> "Xor":
1850+
@classmethod
1851+
def from_list(cls, propositions: typing.List[typing.Union["AtLeast", puan.variable]], variable: typing.Union[str, puan.variable] = None) -> "Xor":
18041852

18051853
"""
18061854
Convert from list of propositions to an object of this proposition class.
@@ -1816,7 +1864,7 @@ def from_list(propositions: typing.List[typing.Union["AtLeast", puan.variable]],
18161864
out : :class:`Xor`
18171865
"""
18181866

1819-
return Xor(*propositions, variable=variable)
1867+
return cls(*propositions, variable=variable)
18201868

18211869
def to_json(self) -> typing.Dict[str, typing.Any]:
18221870

@@ -1841,6 +1889,10 @@ def to_json(self) -> typing.Dict[str, typing.Any]:
18411889
return d
18421890

18431891

1892+
class ExactlyOne(Xor):
1893+
pass
1894+
1895+
18441896
class Not():
18451897

18461898
"""
@@ -1971,7 +2023,7 @@ def to_json(self) -> typing.Dict[str, typing.Any]:
19712023
d['id'] = self.id
19722024
return d
19732025

1974-
def from_json(data: dict, class_map: list = [puan.variable,AtLeast,AtMost,All,Any,Xor,Not,XNor,Imply]) -> typing.Any:
2026+
def from_json(data: dict, class_map: list = [puan.variable,AtLeast,AtMost,All,Any,Xor,ExactlyOne,Not,XNor,Imply]) -> typing.Any:
19752027

19762028
"""
19772029
Convert from json data to a proposition.

puan/ndarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2055,7 +2055,7 @@ def _vectors_from_prios(self, prios: typing.List[typing.Dict[str, int]]) -> nump
20552055
def select(
20562056
self,
20572057
*prios: typing.List[typing.Dict[str, int]],
2058-
solver: typing.Callable[[ge_polyhedron, typing.Dict[str, int]], typing.Iterable[typing.Tuple[typing.List[int], int, int]]] = None,
2058+
solver: typing.Callable[[ge_polyhedron, typing.Iterable[numpy.ndarray]], typing.Iterable[typing.Tuple[typing.Optional[numpy.ndarray], typing.Optional[int], int]]] = None,
20592059
) -> itertools.starmap:
20602060

20612061
"""

pyproject.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@ classifiers = [
1616
"Operating System :: OS Independent",
1717
]
1818
dependencies = [
19-
"numpy==1.22.3",
20-
"more-itertools==8.12.0",
21-
"maz==0.0.6",
22-
"dictdiffer==0.9.0",
23-
"puan-rspy==0.2.8"
19+
"numpy>=1.23.5",
20+
"more-itertools>=8.12.0",
21+
"maz>=0.0.6",
22+
"puan-rspy>=0.2.9"
2423
]
2524
keywords = ["combinatorial optimization", "milp", "mllp", "ilp", "linear-programming", "optimization"]
2625

tests/test_puan.py

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2025,6 +2025,116 @@ def dummy_solver(x, y):
20252025

20262026
assert list(actual) == list(expected)
20272027

2028+
def test_solve_select():
2029+
2030+
def dummy_solver(polyhedron, objectives):
2031+
return map(
2032+
lambda x: (x, x.sum(), 5),
2033+
numpy.ones((len(list(objectives)), polyhedron.A.shape[1]))
2034+
)
2035+
2036+
assert all(
2037+
itertools.starmap(
2038+
lambda model, expected, inc_virt: next(model.solve([{}], solver=dummy_solver, include_virtual_variables=inc_virt))[0] == expected,
2039+
[
2040+
(
2041+
pg.All(
2042+
pg.Any(*"ab"),
2043+
pg.Any(*"xy"),
2044+
),
2045+
{
2046+
"a": 1,
2047+
"b": 1,
2048+
"x": 1,
2049+
"y": 1,
2050+
},
2051+
False,
2052+
),
2053+
(
2054+
pg.All(
2055+
pg.Any(*"ab", variable="B"),
2056+
pg.Any(*"xy", variable="C"),
2057+
),
2058+
{
2059+
"B": 1,
2060+
"C": 1,
2061+
"a": 1,
2062+
"b": 1,
2063+
"x": 1,
2064+
"y": 1,
2065+
},
2066+
False,
2067+
),
2068+
(
2069+
pg.All(
2070+
pg.Any(*"ab"),
2071+
pg.Any(*"xy"),
2072+
),
2073+
{
2074+
"VARbe8d74d8fa4921a5b81b2aac8134ab779c2c68235100ac45f5b33779da3c647c": 1,
2075+
"VARf4ee25a75ae7daf40eefdd224ace61603dd2df6a77015889d190d878057b54d4": 1,
2076+
"a": 1,
2077+
"b": 1,
2078+
"x": 1,
2079+
"y": 1,
2080+
},
2081+
True,
2082+
),
2083+
]
2084+
)
2085+
)
2086+
2087+
assert all(
2088+
itertools.starmap(
2089+
lambda model, expected, inc_virt: next(model.solve([{}], include_virtual_variables=inc_virt))[0] == expected,
2090+
[
2091+
(
2092+
pg.All(
2093+
pg.Any(*"ab"),
2094+
pg.Any(*"xy"),
2095+
),
2096+
{
2097+
"a": 1,
2098+
"b": 0,
2099+
"x": 1,
2100+
"y": 0,
2101+
},
2102+
False,
2103+
),
2104+
(
2105+
pg.All(
2106+
pg.Any(*"ab", variable="B"),
2107+
pg.Any(*"xy", variable="C"),
2108+
),
2109+
{
2110+
"B": 1,
2111+
"C": 1,
2112+
"a": 1,
2113+
"b": 0,
2114+
"x": 1,
2115+
"y": 0,
2116+
},
2117+
False,
2118+
),
2119+
(
2120+
pg.All(
2121+
pg.Any(*"ab"),
2122+
pg.Any(*"xy"),
2123+
),
2124+
{
2125+
"VARbe8d74d8fa4921a5b81b2aac8134ab779c2c68235100ac45f5b33779da3c647c": 1,
2126+
"VARf4ee25a75ae7daf40eefdd224ace61603dd2df6a77015889d190d878057b54d4": 1,
2127+
"a": 1,
2128+
"b": 0,
2129+
"x": 1,
2130+
"y": 0,
2131+
},
2132+
True,
2133+
),
2134+
]
2135+
)
2136+
)
2137+
20282138
def test_default_prio_vector_weights():
20292139

20302140
"""
@@ -2268,9 +2378,6 @@ def test_at_leasts():
22682378

22692379
with pytest.raises(Exception):
22702380
pg.AtLeast(value=1, propositions=None, variable="A")
2271-
2272-
with pytest.raises(Exception):
2273-
pg.AtLeast(propositions=[], value=1, variable="A")
22742381

22752382
with pytest.raises(Exception):
22762383
pg.AtLeast(propositions=["a"], value=1, sign=-2)

0 commit comments

Comments
 (0)