Refactor mount-point sorting

Currently we keep a global list of mount-points defined in the
configuration and automatically setup dependencies between mount nodes
based on their global "mount order" (i.e. higher directories mount
first).

The current method for achieving this is roughly to add the mount
points to a dictionary indexed my mount-point, then at "get_edge()"
call build the sorted list ... unless it has already been built
because this gets called for every node.

It seems much simpler to simply keep a sorted list of the
MountPointNode objects as we add them.  We don't need to implement a
sorting algorithm then, we can just use sort() and implement __lt__
for the nodes.

I believe the existing mount-order unit testing is sufficient; I'm
struggling to find a valid configuration where the mount-order is
*not* correctly specified in the configuration graph.

Change-Id: Idc05cdf42d95e230b9906773aa2b4a3b0f075598
This commit is contained in:
Ian Wienand 2017-05-31 10:41:02 +10:00
parent b85de3cd9e
commit 35a1e7bee9
4 changed files with 27 additions and 72 deletions

View File

@ -20,17 +20,13 @@ from diskimage_builder.block_device.exception \
from diskimage_builder.block_device.plugin import NodeBase
from diskimage_builder.block_device.plugin import PluginBase
from diskimage_builder.block_device.utils import exec_sudo
from diskimage_builder.block_device.utils import sort_mount_points
logger = logging.getLogger(__name__)
# There is the need to collect all mount points to be able to
# sort them in a sensible way.
mount_points = {}
# The order of mounting and unmounting is important.
sorted_mount_points = None
sorted_mount_points = []
class MountPointNode(NodeBase):
@ -47,45 +43,30 @@ class MountPointNode(NodeBase):
setattr(self, pname, config[pname])
logger.debug("MountPoint created [%s]", self)
def get_node(self):
global mount_points
if self.mount_point in mount_points:
raise BlockDeviceSetupException(
"Mount point [%s] specified more than once"
% self.mount_point)
logger.debug("Insert node [%s]", self)
mount_points[self.mount_point] = self
return self
def __lt__(self, other):
# in words: if the other mount-point has us as it's
# parent, we come before it (less than it). e.g.
# /var < /var/log < /var/log/foo
return other.mount_point.startswith(self.mount_point)
def get_edges(self):
"""Insert all edges
After inserting all the nodes, the order of the mounting and
umounting can be computed. There is the need to mount
mount-points that contain other mount-points first.
Example: '/var' must be mounted before '/var/log'. If not the
second is not used for files at all.
The dependency edge is created in all cases from the base
element (typically a mkfs) and, if this is not the 'first'
mount-point, also depend on the mount point before. This
ensures that during mounting (and umounting) the correct
mount-point, an edge is created from the mount-point before in
"sorted order" (see :func:`sort_mount_points`). This ensures
that during mounting (and umounting) the globally correct
order is used.
"""
edge_from = []
edge_to = []
global mount_points
global sorted_mount_points
if sorted_mount_points is None:
logger.debug("Mount points [%s]", mount_points)
sorted_mount_points = sort_mount_points(mount_points.keys())
logger.info("Sorted mount points [%s]", sorted_mount_points)
# Look for the occurance in the list
mpi = sorted_mount_points.index(self.mount_point)
# If we are not first, add our parent in the global dependency
# list
mpi = sorted_mount_points.index(self)
if mpi > 0:
# If not the first: add also the dependency
dep = mount_points[sorted_mount_points[mpi - 1]]
dep = sorted_mount_points[mpi - 1]
edge_from.append(dep.name)
edge_from.append(self.base)
@ -125,20 +106,22 @@ class Mount(PluginBase):
def __init__(self, config, defaults):
super(Mount, self).__init__()
self.mount_points = {}
if 'mount-base' not in defaults:
raise BlockDeviceSetupException(
"Mount default config needs 'mount-base'")
self.mount_base = defaults['mount-base']
self.node = MountPointNode(defaults['mount-base'], config)
mp = MountPointNode(self.mount_base, config)
self.mount_points[mp.get_name()] = mp
# save this new node to the global mount-point list and
# re-order it.
global sorted_mount_points
mount_points = [x.mount_point for x in sorted_mount_points]
if self.node.mount_point in mount_points:
raise BlockDeviceSetupException(
"Mount point [%s] specified more than once"
% self.node.mount_point)
sorted_mount_points.append(self.node)
sorted_mount_points.sort()
logger.debug("Ordered mounts now: %s", sorted_mount_points)
def get_nodes(self):
global sorted_mount_points
assert sorted_mount_points is None
nodes = []
for _, mp in self.mount_points.items():
nodes.append(mp.get_node())
return nodes
return [self.node]

View File

@ -32,8 +32,7 @@ class TestConfig(TestBase):
import diskimage_builder.block_device.level2.mkfs
diskimage_builder.block_device.level2.mkfs.file_system_labels = set()
import diskimage_builder.block_device.level3.mount
diskimage_builder.block_device.level3.mount.mount_points = {}
diskimage_builder.block_device.level3.mount.sorted_mount_points = None
diskimage_builder.block_device.level3.mount.sorted_mount_points = []
class TestGraphGeneration(TestConfig):

View File

@ -123,23 +123,3 @@ def exec_sudo(cmd):
if proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode,
' '.join(sudo_cmd))
def sort_mount_points(mount_points):
logger.debug("sort_mount_points called [%s]", mount_points)
def insert_sorted(mp, sorted_mount_points):
if len(sorted_mount_points) == 0:
sorted_mount_points.append(mp)
return
for idx in range(0, len(sorted_mount_points)):
if sorted_mount_points[idx].startswith(mp):
sorted_mount_points.insert(idx, mp)
return
sorted_mount_points.append(mp)
sorted_mount_points = []
for mp in mount_points:
insert_sorted(mp, sorted_mount_points)
logger.debug("sort_mount_points result [%s]", sorted_mount_points)
return sorted_mount_points

View File

@ -15,7 +15,6 @@
from diskimage_builder.block_device.utils import parse_abs_size_spec
from diskimage_builder.block_device.utils import parse_rel_size_spec
from diskimage_builder.block_device.utils import sort_mount_points
import testtools
@ -48,9 +47,3 @@ class TestBlockDeviceUtils(testtools.TestCase):
"""Call parse_abs_size_spec with a completely broken unit spec"""
self.assertRaises(RuntimeError, parse_abs_size_spec, "_+!HuHi+-=")
def test_sort_mount_points(self):
"""Run sort_mount_points with a set of paths"""
smp = sort_mount_points(["/boot", "/", "/var/tmp", "/var"])
self.assertEqual(['/', '/boot', '/var', '/var/tmp'], smp)