diff --git a/requirements.txt b/requirements.txt index d6df3a4..73c0075 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy==1.15.1 -scipy==1.0.1 -ortools==6.9.5824 -packaging==16.8 \ No newline at end of file +numpy>=1.15.1 +scipy>=1.0.1 +ortools>=6.9.5824 +packaging>=16.8 \ No newline at end of file diff --git a/seriate.py b/seriate.py index 94305d5..d9859a2 100644 --- a/seriate.py +++ b/seriate.py @@ -11,8 +11,10 @@ ortools_version = Version(ortools.__version__) ortools6 = Version("6.0.0") <= ortools_version < Version("7") ortools7 = Version("7.0.0") <= ortools_version < Version("8") -if not ortools6 and not ortools7: - raise ImportError("No valid version of ortools installed. Please install ortools 6 or 7.") +ortools8 = Version("8.0.0") <= ortools_version < Version("9") +ortools9 = Version("9.0.0") <= ortools_version < Version("10") +if not ortools6 and not ortools7 and not ortools8 and not ortools9: + raise ImportError("No valid version of ortools installed. Please install ortools 6 or 7 or 8 or 9.") class IncompleteSolutionError(Exception): @@ -101,12 +103,12 @@ def _seriate(dists: numpy.ndarray, approximation_multiplier=1000, timeout=2.0) - if ortools6: routing = pywrapcp.RoutingModel(size + 1, 1, size) - elif ortools7: + elif ortools7 or ortools8 or ortools9: manager = pywrapcp.RoutingIndexManager(size + 1, 1, size) routing = pywrapcp.RoutingModel(manager) def dist_callback(x, y): - if ortools7: + if ortools7 or ortools8 or ortools9: x = manager.IndexToNode(x) y = manager.IndexToNode(y) if x == size or y == size or x == y: @@ -125,7 +127,7 @@ def dist_callback(x, y): routing.SetArcCostEvaluatorOfAllVehicles(dist_callback) search_parameters = pywrapcp.RoutingModel.DefaultSearchParameters() search_parameters.time_limit_ms = int(timeout * 1000) - elif ortools7: + elif ortools7 or ortools8 or ortools9: routing.SetArcCostEvaluatorOfAllVehicles(routing.RegisterTransitCallback(dist_callback)) search_parameters = pywrapcp.DefaultRoutingSearchParameters() search_parameters.time_limit.FromMilliseconds(int(timeout * 1000)) @@ -142,7 +144,7 @@ def dist_callback(x, y): while not routing.IsEnd(index): if ortools6: node = routing.IndexToNode(index) - elif ortools7: + elif ortools7 or ortools8 or ortools9: node = manager.IndexToNode(index) if node < size: route.append(node) diff --git a/setup.py b/setup.py index 1b52a96..c66b4cd 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ download_url="https://github.com/src-d/seriate", py_modules=["seriate"], keywords=["seriation"], - install_requires=["numpy>=1.0", "ortools>=6.7.4973,<8", "packaging>=16.0"], + install_requires=["numpy>=1.0", "ortools>=6.7.4973,<=9", "packaging>=16.0"], tests_require=["scipy>=1.0"], package_data={"": ["LICENSE.md", "README.md", "requirements.txt"]}, classifiers=[ @@ -31,5 +31,6 @@ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], )