Add state to NodeBase class

Making the global state reference a defined part of the node makes
some parts of the block device processing easier and removes the need
for other global values.

The state is passed to PluginNodeBase.__init__() and expected to be
passed into all nodes as they are created.  NodeBase.__init__() is
updated with the new paramater 'state'.

The parameter is removed from the create() call as nodes can simply
reference it at any point as "self.state".

This is similar to 1cdc8b20373c5d582ea928cfd7334469ff36dbce, except it
is based on I68840594a34af28d41d9522addcfd830bd203b97 which loads the
node-list from pickled state for later cmd_* calls.  Thus we only
build the state *once*, at cmd_create() time as we build the node
list.

Change-Id: I468dbf5134947629f125504513703d6f2cdace59
This commit is contained in:
Ian Wienand 2017-06-01 14:31:49 +10:00
parent e82e0097a9
commit 824a9e91c4
14 changed files with 117 additions and 99 deletions

View File

@ -44,18 +44,19 @@ def _load_json(file_name):
class BlockDeviceState(collections.MutableMapping):
"""The global state singleton
An reference to an instance of this object is passed between nodes
as a global repository. It wraps a single dictionary "state"
and provides a few helper functions.
An reference to an instance of this object is saved into nodes as
a global repository. It wraps a single dictionary "state" and
provides a few helper functions.
This is used in two contexts:
The state ends up used in two contexts:
- The state is built by the :func:`NodeBase.create` commands as
called during :func:`BlockDevice.cmd_create`. It is then
persisted to disk by :func:`save_state`
- The node list (including this state) is pickled and dumped
between cmd_create() and later cmd_* calls that need to call
the nodes.
- Later calls (cleanup, umount, etc) load the state dictionary
from disk and are thus passed the full state.
- Some other cmd_* calls, such as cmd_writefstab, only need
access to values inside the state and not the whole node list,
and load it from the json dump created after cmd_create()
"""
# XXX:
# - we could implement getters/setters such that if loaded from
@ -373,9 +374,9 @@ class BlockDevice(object):
# Create a new, empty state
state = BlockDeviceState()
try:
dg, call_order = create_graph(self.config, self.params)
dg, call_order = create_graph(self.config, self.params, state)
for node in call_order:
node.create(state, rollback)
node.create(rollback)
except Exception:
logger.exception("Create failed; rollback initiated")
for rollback_cb in reversed(rollback):

View File

@ -142,13 +142,15 @@ def config_tree_to_graph(config):
return output
def create_graph(config, default_config):
def create_graph(config, default_config, state):
"""Generate configuration digraph
Generate the configuration digraph from the config
:param config: graph configuration file
:param default_config: default parameters (from --params)
:param state: reference to global state dictionary.
Passed to :func:`PluginBase.__init__`
:return: tuple with the graph object (a :class:`nx.Digraph`),
ordered list of :class:`NodeBase` objects
@ -175,7 +177,7 @@ def create_graph(config, default_config):
("Config element [%s] is not implemented" % cfg_obj_name))
plugin = _extensions[cfg_obj_name].plugin
assert issubclass(plugin, PluginBase)
cfg_obj = plugin(cfg_obj_val, default_config)
cfg_obj = plugin(cfg_obj_val, default_config, state)
# Ask the plugin for the nodes it would like to insert
# into the graph. Some plugins, such as partitioning,

View File

@ -80,10 +80,10 @@ class LocalLoopNode(NodeBase):
This class handles local loop devices that can be used
for VM image installation.
"""
def __init__(self, config, default_config):
def __init__(self, config, default_config, state):
logger.debug("Creating LocalLoop object; config [%s] "
"default_config [%s]", config, default_config)
super(LocalLoopNode, self).__init__(config['name'])
super(LocalLoopNode, self).__init__(config['name'], state)
if 'size' in config:
self.size = parse_abs_size_spec(config['size'])
logger.debug("Image size [%s]", self.size)
@ -100,7 +100,7 @@ class LocalLoopNode(NodeBase):
"""Because this is created without base, there are no edges."""
return ([], [])
def create(self, state, rollback):
def create(self, rollback):
logger.debug("[%s] Creating loop on [%s] with size [%d]",
self.name, self.filename, self.size)
@ -110,10 +110,10 @@ class LocalLoopNode(NodeBase):
block_device = loopdev_attach(self.filename)
rollback.append(lambda: loopdev_detach(block_device))
if 'blockdev' not in state:
state['blockdev'] = {}
if 'blockdev' not in self.state:
self.state['blockdev'] = {}
state['blockdev'][self.name] = {"device": block_device,
self.state['blockdev'][self.name] = {"device": block_device,
"image": self.filename}
logger.debug("Created loop name [%s] device [%s] image [%s]",
self.name, block_device, self.filename)
@ -131,9 +131,9 @@ class LocalLoopNode(NodeBase):
class LocalLoop(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(LocalLoop, self).__init__()
self.node = LocalLoopNode(config, defaults)
self.node = LocalLoopNode(config, defaults, state)
def get_nodes(self):
return [self.node]

View File

@ -25,9 +25,9 @@ class PartitionNode(NodeBase):
flag_boot = 1
flag_primary = 2
def __init__(self, config, parent, prev_partition):
def __init__(self, config, state, parent, prev_partition):
super(PartitionNode, self).__init__(config['name'])
super(PartitionNode, self).__init__(config['name'], state)
self.base = config['base']
self.partitioning = parent
@ -65,5 +65,5 @@ class PartitionNode(NodeBase):
edge_from.append(self.prev_partition.name)
return (edge_from, edge_to)
def create(self, state, rollback):
self.partitioning.create(state, rollback)
def create(self, rollback):
self.partitioning.create(rollback)

View File

@ -32,10 +32,15 @@ logger = logging.getLogger(__name__)
class Partitioning(PluginBase):
def __init__(self, config, default_config):
def __init__(self, config, default_config, state):
logger.debug("Creating Partitioning object; config [%s]", config)
super(Partitioning, self).__init__()
# Unlike other PluginBase we are somewhat persistent, as the
# partition nodes call back to us (see create() below). We
# need to keep this reference.
self.state = state
# Because using multiple partitions of one base is done
# within one object, there is the need to store a flag if the
# creation of the partitions was already done.
@ -76,7 +81,7 @@ class Partitioning(PluginBase):
prev_partition = None
for part_cfg in config['partitions']:
np = PartitionNode(part_cfg, self, prev_partition)
np = PartitionNode(part_cfg, state, self, prev_partition)
self.partitions.append(np)
prev_partition = np
@ -127,12 +132,12 @@ class Partitioning(PluginBase):
exec_sudo(["kpartx", "-avs", device_path])
def create(self, state, rollback):
def create(self, rollback):
# not this is NOT a node and this is not called directly! The
# create() calls in the partition nodes this plugin has
# created are calling back into this.
image_path = state['blockdev'][self.base]['image']
device_path = state['blockdev'][self.base]['device']
image_path = self.state['blockdev'][self.base]['image']
device_path = self.state['blockdev'][self.base]['device']
logger.info("Creating partition on [%s] [%s]", self.base, image_path)
# This is a bit of a hack. Each of the partitions is actually
@ -166,7 +171,7 @@ class Partitioning(PluginBase):
logger.debug("Create partition [%s] [%d]",
part_name, part_no)
partition_device_name = device_path + "p%d" % part_no
state['blockdev'][part_name] \
self.state['blockdev'][part_name] \
= {'device': partition_device_name}
partition_devices.add(partition_device_name)

View File

@ -43,9 +43,9 @@ file_system_max_label_length = {
class FilesystemNode(NodeBase):
def __init__(self, config):
def __init__(self, config, state):
logger.debug("Create filesystem object; config [%s]", config)
super(FilesystemNode, self).__init__(config['name'])
super(FilesystemNode, self).__init__(config['name'], state)
# Parameter check (mandatory)
for pname in ['base', 'type']:
@ -102,7 +102,7 @@ class FilesystemNode(NodeBase):
edge_to = []
return (edge_from, edge_to)
def create(self, state, rollback):
def create(self, rollback):
cmd = ["mkfs"]
cmd.extend(['-t', self.type])
@ -121,17 +121,17 @@ class FilesystemNode(NodeBase):
if self.type in ('ext2', 'ext3', 'ext4', 'xfs'):
cmd.append('-q')
if 'blockdev' not in state:
state['blockdev'] = {}
device = state['blockdev'][self.base]['device']
if 'blockdev' not in self.state:
self.state['blockdev'] = {}
device = self.state['blockdev'][self.base]['device']
cmd.append(device)
logger.debug("Creating fs command [%s]", cmd)
exec_sudo(cmd)
if 'filesys' not in state:
state['filesys'] = {}
state['filesys'][self.name] \
if 'filesys' not in self.state:
self.state['filesys'] = {}
self.state['filesys'][self.name] \
= {'uuid': self.uuid, 'label': self.label,
'fstype': self.type, 'opts': self.opts,
'device': device}
@ -144,10 +144,10 @@ class Mkfs(PluginBase):
systems.
"""
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(Mkfs, self).__init__()
self.filesystems = {}
fs = FilesystemNode(config)
fs = FilesystemNode(config, state)
self.filesystems[fs.get_name()] = fs
def get_nodes(self):

View File

@ -31,8 +31,8 @@ sorted_mount_points = []
class MountPointNode(NodeBase):
def __init__(self, mount_base, config):
super(MountPointNode, self).__init__(config['name'])
def __init__(self, mount_base, config, state):
super(MountPointNode, self).__init__(config['name'], state)
# Parameter check
self.mount_base = mount_base
@ -72,7 +72,7 @@ class MountPointNode(NodeBase):
edge_from.append(self.base)
return (edge_from, edge_to)
def create(self, state, rollback):
def create(self, rollback):
logger.debug("mount called [%s]", self.mount_point)
rel_mp = self.mount_point if self.mount_point[0] != '/' \
else self.mount_point[1:]
@ -82,17 +82,17 @@ class MountPointNode(NodeBase):
# file system tree.
exec_sudo(['mkdir', '-p', mount_point])
logger.info("Mounting [%s] to [%s]", self.name, mount_point)
exec_sudo(["mount", state['filesys'][self.base]['device'],
exec_sudo(["mount", self.state['filesys'][self.base]['device'],
mount_point])
if 'mount' not in state:
state['mount'] = {}
state['mount'][self.mount_point] \
if 'mount' not in self.state:
self.state['mount'] = {}
self.state['mount'][self.mount_point] \
= {'name': self.name, 'base': self.base, 'path': mount_point}
if 'mount_order' not in state:
state['mount_order'] = []
state['mount_order'].append(self.mount_point)
if 'mount_order' not in self.state:
self.state['mount_order'] = []
self.state['mount_order'].append(self.mount_point)
def umount(self, state):
logger.info("Called for [%s]", self.name)
@ -103,13 +103,13 @@ class MountPointNode(NodeBase):
class Mount(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(Mount, self).__init__()
if 'mount-base' not in defaults:
raise BlockDeviceSetupException(
"Mount default config needs 'mount-base'")
self.node = MountPointNode(defaults['mount-base'], config)
self.node = MountPointNode(defaults['mount-base'], config, state)
# save this new node to the global mount-point list and
# re-order it.

View File

@ -22,8 +22,8 @@ logger = logging.getLogger(__name__)
class FstabNode(NodeBase):
def __init__(self, config, params):
super(FstabNode, self).__init__(config['name'])
def __init__(self, config, state):
super(FstabNode, self).__init__(config['name'], state)
self.base = config['base']
self.options = config.get('options', 'defaults')
self.dump_freq = config.get('dump-freq', 0)
@ -34,13 +34,13 @@ class FstabNode(NodeBase):
edge_to = []
return (edge_from, edge_to)
def create(self, state, rollback):
def create(self, rollback):
logger.debug("fstab create called [%s]", self.name)
if 'fstab' not in state:
state['fstab'] = {}
if 'fstab' not in self.state:
self.state['fstab'] = {}
state['fstab'][self.base] = {
self.state['fstab'][self.base] = {
'name': self.name,
'base': self.base,
'options': self.options,
@ -50,10 +50,10 @@ class FstabNode(NodeBase):
class Fstab(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(Fstab, self).__init__()
self.node = FstabNode(config, defaults)
self.node = FstabNode(config, state)
def get_nodes(self):
return [self.node]

View File

@ -32,7 +32,7 @@ class NodeBase(object):
Every node has a unique string ``name``. This is its key in the
graph and used for edge relationships. Implementations must
ensure they initialize it; e.g.
ensure they initalize it; e.g.
.. code-block:: python
@ -41,8 +41,9 @@ class NodeBase(object):
super(FooNode, self).__init__(name)
"""
def __init__(self, name):
def __init__(self, name, state):
self.name = name
self.state = state
def get_name(self):
return self.name
@ -74,7 +75,7 @@ class NodeBase(object):
return
@abc.abstractmethod
def create(self, state, rollback):
def create(self, rollback):
"""Main creation driver
This is the main driver function. After the graph is
@ -82,12 +83,6 @@ class NodeBase(object):
Arguments:
:param state: A shared dictionary of prior results. This
dictionary is passed by reference to each call, meaning any
entries inserted will be available to subsequent :func:`create`
calls of following nodes. The ``state`` dictionary will be
saved and available to other calls.
:param rollback: A shared list of functions to be called in
the failure case. Nodes should only append to this list.
On failure, the callbacks will be processed in reverse
@ -164,13 +159,16 @@ class PluginBase(object):
argument_a: bar
argument_b: baz
The ``__init__`` function will be passed two arguments:
The ``__init__`` function will be passed three arguments:
``config``
The full configuration dictionary for the entry.
A unique ``name`` entry can be assumed. In most cases
a ``base`` entry will be present giving the parent node
(see :func:`NodeBase.get_edges`).
``state``
A reference to the gobal state dictionary. This should be
passed to :func:`NodeBase.__init__` on node creation
``defaults``
The global defaults dictionary (see ``--params``)
@ -183,9 +181,9 @@ class PluginBase(object):
class Foo(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(Foo, self).__init__()
self.node = FooNode(config.name, ...)
self.node = FooNode(config.name, state, ...)
def get_nodes(self):
return [self.node]

View File

@ -22,20 +22,22 @@ logger = logging.getLogger(__name__)
class TestANode(NodeBase):
def __init__(self, name):
def __init__(self, name, state):
logger.debug("Create test 1")
super(TestANode, self).__init__(name)
super(TestANode, self).__init__(name, state)
# put something in the state for test_b to check for
state['test_init_state'] = 'here'
def get_edges(self):
# this is like the loop node; it's a root and doesn't have a
# base
return ([], [])
def create(self, state, rollback):
def create(self, rollback):
# put some fake entries into state
state['test_a'] = {}
state['test_a']['value'] = 'foo'
state['test_a']['value2'] = 'bar'
self.state['test_a'] = {}
self.state['test_a']['value'] = 'foo'
self.state['test_a']['value2'] = 'bar'
return
def umount(self, state):
@ -45,9 +47,9 @@ class TestANode(NodeBase):
class TestA(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(TestA, self).__init__()
self.node = TestANode(config['name'])
self.node = TestANode(config['name'], state)
def get_nodes(self):
return [self.node]

View File

@ -22,17 +22,20 @@ logger = logging.getLogger(__name__)
class TestBNode(NodeBase):
def __init__(self, name, base):
def __init__(self, name, state, base):
logger.debug("Create test 1")
super(TestBNode, self).__init__(name)
super(TestBNode, self).__init__(name, state)
self.base = base
def get_edges(self):
# this should have been inserted by test_a before
# we are called
assert self.state['test_init_state'] == 'here'
return ([self.base], [])
def create(self, state, rollback):
state['test_b'] = {}
state['test_b']['value'] = 'baz'
def create(self, rollback):
self.state['test_b'] = {}
self.state['test_b']['value'] = 'baz'
return
def umount(self, state):
@ -44,9 +47,10 @@ class TestBNode(NodeBase):
class TestB(PluginBase):
def __init__(self, config, defaults):
def __init__(self, config, defaults, state):
super(TestB, self).__init__()
self.node = TestBNode(config['name'],
state,
config['base'])
def get_nodes(self):

View File

@ -104,7 +104,7 @@ class TestCreateGraph(TestGraphGeneration):
self.assertRaisesRegex(BlockDeviceSetupException,
"Edge not defined: this_is_not_a_node",
create_graph,
config, self.fake_default_config)
config, self.fake_default_config, {})
# Test a graph with bad edge pointing to an invalid node
def test_duplicate_name(self):
@ -113,13 +113,13 @@ class TestCreateGraph(TestGraphGeneration):
"Duplicate node name: "
"this_is_a_duplicate",
create_graph,
config, self.fake_default_config)
config, self.fake_default_config, {})
# Test digraph generation from deep_graph config file
def test_deep_graph_generator(self):
config = self.load_config_file('deep_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config)
graph, call_order = create_graph(config, self.fake_default_config, {})
call_order_list = [n.name for n in call_order]
@ -136,7 +136,7 @@ class TestCreateGraph(TestGraphGeneration):
def test_multiple_partitions_graph_generator(self):
config = self.load_config_file('multiple_partitions_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config)
graph, call_order = create_graph(config, self.fake_default_config, {})
call_order_list = [n.name for n in call_order]
# The sort creating call_order_list is unstable.

View File

@ -28,9 +28,16 @@ class TestMountOrder(tc.TestGraphGeneration):
config = self.load_config_file('multiple_partitions_graph.yaml')
graph, call_order = create_graph(config, self.fake_default_config)
state = {}
graph, call_order = create_graph(config, self.fake_default_config,
state)
rollback = []
# build up some fake state so that we don't have to mock out
# all the parent calls that would really make these values, as
# we just want to test MountPointNode
state['filesys'] = {}
state['filesys']['mkfs_root'] = {}
state['filesys']['mkfs_root']['device'] = 'fake'
@ -39,14 +46,12 @@ class TestMountOrder(tc.TestGraphGeneration):
state['filesys']['mkfs_var_log'] = {}
state['filesys']['mkfs_var_log']['device'] = 'fake'
rollback = []
for node in call_order:
if isinstance(node, MountPointNode):
# XXX: do we even need to create? We could test the
# sudo arguments from the mock in the below asserts
# too
node.create(state, rollback)
node.create(rollback)
# ensure that partitions are mounted in order root->var->var/log
self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log'])

View File

@ -72,7 +72,8 @@ class TestState(TestStateBase):
self.assertDictEqual(state,
{'test_a': {'value': 'foo',
'value2': 'bar'},
'test_b': {'value': 'baz'}})
'test_b': {'value': 'baz'},
'test_init_state': 'here'})
pickle_file = bd_obj.node_pickle_file_name
self.assertThat(pickle_file, FileExists())