diff --git a/diskimage_builder/block_device/blockdevice.py b/diskimage_builder/block_device/blockdevice.py index 562dc335..30c2172d 100644 --- a/diskimage_builder/block_device/blockdevice.py +++ b/diskimage_builder/block_device/blockdevice.py @@ -20,7 +20,6 @@ import os import pickle import pprint import shutil -import sys import yaml from diskimage_builder.block_device.config import config_tree_to_graph @@ -370,18 +369,21 @@ class BlockDevice(object): logger.info("create() called") logger.debug("Using config [%s]", self.config) - rollback = [] # Create a new, empty state state = BlockDeviceState() try: dg, call_order = create_graph(self.config, self.params, state) for node in call_order: - node.create(rollback) + node.create() except Exception: logger.exception("Create failed; rollback initiated") - for rollback_cb in reversed(rollback): - rollback_cb() - sys.exit(1) + reverse_order = reversed(call_order) + for node in reverse_order: + node.rollback() + # save the state for debugging + state.save_state(self.state_json_file_name) + logger.error("Rollback complete, exiting") + raise # dump state and nodes, in order # XXX: we only dump the call_order (i.e. nodes) not the whole diff --git a/diskimage_builder/block_device/level0/localloop.py b/diskimage_builder/block_device/level0/localloop.py index 8281c07d..6d608403 100644 --- a/diskimage_builder/block_device/level0/localloop.py +++ b/diskimage_builder/block_device/level0/localloop.py @@ -100,15 +100,15 @@ class LocalLoopNode(NodeBase): """Because this is created without base, there are no edges.""" return ([], []) - def create(self, rollback): + def create(self): logger.debug("[%s] Creating loop on [%s] with size [%d]", self.name, self.filename, self.size) - rollback.append(lambda: image_delete(self.filename)) + self.add_rollback(image_delete, self.filename) image_create(self.filename, self.size) block_device = loopdev_attach(self.filename) - rollback.append(lambda: loopdev_detach(block_device)) + self.add_rollback(loopdev_detach, block_device) if 'blockdev' not in self.state: self.state['blockdev'] = {} diff --git a/diskimage_builder/block_device/level1/partition.py b/diskimage_builder/block_device/level1/partition.py index 4ee6835f..c6f812c5 100644 --- a/diskimage_builder/block_device/level1/partition.py +++ b/diskimage_builder/block_device/level1/partition.py @@ -65,5 +65,5 @@ class PartitionNode(NodeBase): edge_from.append(self.prev_partition.name) return (edge_from, edge_to) - def create(self, rollback): - self.partitioning.create(rollback) + def create(self): + self.partitioning.create() diff --git a/diskimage_builder/block_device/level1/partitioning.py b/diskimage_builder/block_device/level1/partitioning.py index 4a55c676..ff2b2bc7 100644 --- a/diskimage_builder/block_device/level1/partitioning.py +++ b/diskimage_builder/block_device/level1/partitioning.py @@ -132,7 +132,7 @@ class Partitioning(PluginBase): exec_sudo(["kpartx", "-avs", device_path]) - def create(self, rollback): + def create(self): # not this is NOT a node and this is not called directly! The # create() calls in the partition nodes this plugin has # created are calling back into this. diff --git a/diskimage_builder/block_device/level2/mkfs.py b/diskimage_builder/block_device/level2/mkfs.py index 8a2c18be..7bebb8ca 100644 --- a/diskimage_builder/block_device/level2/mkfs.py +++ b/diskimage_builder/block_device/level2/mkfs.py @@ -103,7 +103,7 @@ class FilesystemNode(NodeBase): edge_to = [] return (edge_from, edge_to) - def create(self, rollback): + def create(self): cmd = ["mkfs"] cmd.extend(['-t', self.type]) diff --git a/diskimage_builder/block_device/level3/mount.py b/diskimage_builder/block_device/level3/mount.py index 473c4b3b..43b42fa6 100644 --- a/diskimage_builder/block_device/level3/mount.py +++ b/diskimage_builder/block_device/level3/mount.py @@ -71,7 +71,7 @@ class MountPointNode(NodeBase): edge_from.append(self.base) return (edge_from, edge_to) - def create(self, rollback): + def create(self): logger.debug("mount called [%s]", self.mount_point) rel_mp = self.mount_point if self.mount_point[0] != '/' \ else self.mount_point[1:] diff --git a/diskimage_builder/block_device/level4/fstab.py b/diskimage_builder/block_device/level4/fstab.py index 7a30e7fd..38f614da 100644 --- a/diskimage_builder/block_device/level4/fstab.py +++ b/diskimage_builder/block_device/level4/fstab.py @@ -34,7 +34,7 @@ class FstabNode(NodeBase): edge_to = [] return (edge_from, edge_to) - def create(self, rollback): + def create(self): logger.debug("fstab create called [%s]", self.name) if 'fstab' not in self.state: diff --git a/diskimage_builder/block_device/plugin.py b/diskimage_builder/block_device/plugin.py index 1e2e0a5f..64fc199e 100644 --- a/diskimage_builder/block_device/plugin.py +++ b/diskimage_builder/block_device/plugin.py @@ -11,6 +11,7 @@ # under the License. import abc +import logging import six # @@ -18,6 +19,8 @@ import six # processing. This defines the abstract classes for both. # +logger = logging.getLogger(__name__) + @six.add_metaclass(abc.ABCMeta) class NodeBase(object): @@ -44,12 +47,42 @@ class NodeBase(object): def __init__(self, name, state): self.name = name self.state = state + self.rollbacks = [] def get_name(self): return self.name + def add_rollback(self, func, *args, **kwargs): + """Add a call for rollback + + Functions registered with this method will be called in + reverse-order in the case of failures during + :func:`Nodebase.create`. + + :param func: function to call + :param args: arguments + :param kwargs: keyword arguments + :return: None + """ + self.rollbacks.append((func, args, kwargs)) + + def rollback(self): + """Initiate rollback + + Call registered rollback in reverse order. This method is + called by the driver in the case of failures during + :func:`Nodebase.create`. + + :return None: + """ + # XXX: maybe ignore SystemExit so we always continue? + logger.debug("Calling rollback for %s", self.name) + for func, args, kwargs in reversed(self.rollbacks): + func(*args, **kwargs) + @abc.abstractmethod def get_edges(self): + """Return the dependencies/edges for this node This function will be called after all nodes are created (this @@ -75,22 +108,14 @@ class NodeBase(object): return @abc.abstractmethod - def create(self, rollback): + def create(self): """Main creation driver This is the main driver function. After the graph is linearised, each node has it's :func:`create` function called. - Arguments: - - :param rollback: A shared list of functions to be called in - the failure case. Nodes should only append to this list. - On failure, the callbacks will be processed in reverse - order. - :raises Exception: A failure should raise an exception. This - will initiate the rollback - + will initiate a rollback. See :func:`Nodebase.add_rollback`. :return: None """ return diff --git a/diskimage_builder/block_device/tests/config/rollback.yaml b/diskimage_builder/block_device/tests/config/rollback.yaml new file mode 100644 index 00000000..be34ebfb --- /dev/null +++ b/diskimage_builder/block_device/tests/config/rollback.yaml @@ -0,0 +1,29 @@ +- test_a: + name: test_node_a + rollback_one_arg: down + rollback_two_arg: you + +- test_b: + base: test_node_a + name: test_node_b + rollback_one_arg: let + rollback_two_arg: gonna + +- test_a: + base: test_node_b + name: test_node_aa + rollback_one_arg: never + rollback_two_arg: up + +- test_b: + base: test_node_aa + name: test_node_bb + rollback_one_arg: you + rollback_two_arg: give + +- test_a: + base: test_node_bb + name: test_node_aaa + rollback_one_arg: gonna + rollback_two_arg: never + trigger_rollback: yes \ No newline at end of file diff --git a/diskimage_builder/block_device/tests/plugin/test_a.py b/diskimage_builder/block_device/tests/plugin/test_a.py index 1b98d207..137e5e0a 100644 --- a/diskimage_builder/block_device/tests/plugin/test_a.py +++ b/diskimage_builder/block_device/tests/plugin/test_a.py @@ -22,22 +22,46 @@ logger = logging.getLogger(__name__) class TestANode(NodeBase): - def __init__(self, name, state): + def __init__(self, config, state, test_rollback): logger.debug("Create test 1") - super(TestANode, self).__init__(name, state) + super(TestANode, self).__init__(config['name'], state) + # might be a root node, so possibly no base + if 'base' in config: + self.base = config['base'] + # put something in the state for test_b to check for state['test_init_state'] = 'here' - def get_edges(self): - # this is like the loop node; it's a root and doesn't have a - # base - return ([], []) + # If we're doing rollback testing the config has some strings + # set for us + if test_rollback: + self.add_rollback(self.do_rollback, config['rollback_one_arg']) + self.add_rollback(self.do_rollback, config['rollback_two_arg']) + # see if we're the node who is going to fail + self.trigger_rollback = True if 'trigger_rollback' in config else False - def create(self, rollback): + def get_edges(self): + # may not have a base, if used as root node + to = [self.base] if hasattr(self, 'base') else [] + return (to, []) + + def do_rollback(self, string): + # We will check this after all rollbacks to make sure they ran + # in the right order + self.state['rollback_test'].append(string) + + def create(self): # put some fake entries into state self.state['test_a'] = {} self.state['test_a']['value'] = 'foo' self.state['test_a']['value2'] = 'bar' + + if self.trigger_rollback: + # The rollback test will append the strings to this as + # it unrolls, and we'll check it's value at the end + self.state['rollback_test'] = [] + raise RuntimeError("Rollback triggered") + return def umount(self): @@ -49,7 +73,9 @@ class TestA(PluginBase): def __init__(self, config, defaults, state): super(TestA, self).__init__() - self.node = TestANode(config['name'], state) + + test_rollback = True if 'test_rollback' in defaults else False + self.node = TestANode(config, state, test_rollback) def get_nodes(self): return [self.node] diff --git a/diskimage_builder/block_device/tests/plugin/test_b.py b/diskimage_builder/block_device/tests/plugin/test_b.py index 00ac4622..4efc1ced 100644 --- a/diskimage_builder/block_device/tests/plugin/test_b.py +++ b/diskimage_builder/block_device/tests/plugin/test_b.py @@ -22,10 +22,16 @@ logger = logging.getLogger(__name__) class TestBNode(NodeBase): - def __init__(self, name, state, base): + def __init__(self, config, state, test_rollback): logger.debug("Create test 1") - super(TestBNode, self).__init__(name, state) - self.base = base + super(TestBNode, self).__init__(config['name'], state) + self.base = config['base'] + + # If we're doing rollback testing the config has some strings + # set for us. + if test_rollback: + self.add_rollback(self.do_rollback, config['rollback_one_arg']) + self.add_rollback(self.do_rollback, config['rollback_two_arg']) def get_edges(self): # this should have been inserted by test_a before @@ -33,7 +39,12 @@ class TestBNode(NodeBase): assert self.state['test_init_state'] == 'here' return ([self.base], []) - def create(self, rollback): + def do_rollback(self, string): + # We will check this after all rollbacks to make sure they ran + # in the right order + self.state['rollback_test'].append(string) + + def create(self): self.state['test_b'] = {} self.state['test_b']['value'] = 'baz' return @@ -52,9 +63,9 @@ class TestB(PluginBase): def __init__(self, config, defaults, state): super(TestB, self).__init__() - self.node = TestBNode(config['name'], - state, - config['base']) + + test_rollback = True if 'test_rollback' in defaults else False + self.node = TestBNode(config, state, test_rollback) def get_nodes(self): return [self.node] diff --git a/diskimage_builder/block_device/tests/test_mount_order.py b/diskimage_builder/block_device/tests/test_mount_order.py index 617c3950..e2838ea8 100644 --- a/diskimage_builder/block_device/tests/test_mount_order.py +++ b/diskimage_builder/block_device/tests/test_mount_order.py @@ -33,8 +33,6 @@ class TestMountOrder(tc.TestGraphGeneration): graph, call_order = create_graph(config, self.fake_default_config, state) - rollback = [] - # build up some fake state so that we don't have to mock out # all the parent calls that would really make these values, as # we just want to test MountPointNode @@ -51,7 +49,7 @@ class TestMountOrder(tc.TestGraphGeneration): # XXX: do we even need to create? We could test the # sudo arguments from the mock in the below asserts # too - node.create(rollback) + node.create() # ensure that partitions are mounted in order root->var->var/log self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log']) diff --git a/diskimage_builder/block_device/tests/test_state.py b/diskimage_builder/block_device/tests/test_state.py index 639a5abd..92592796 100644 --- a/diskimage_builder/block_device/tests/test_state.py +++ b/diskimage_builder/block_device/tests/test_state.py @@ -123,3 +123,30 @@ class TestState(TestStateBase): bd_obj.cmd_cleanup) # XXX: figure out unit test for umount + + # Test ordering of rollback calls if create() fails + def test_rollback(self): + params = { + 'build-dir': self.build_dir.path, + 'config': self.get_config_file('rollback.yaml'), + 'test_rollback': True + } + + bd_obj = bd.BlockDevice(params) + bd_obj.cmd_init() + + # The config file has flags in that tell the last node to + # fail, which will trigger the rollback. + self.assertRaises(RuntimeError, bd_obj.cmd_create) + + # cmd_create should have persisted this to disk even after the + # failure + state_file = bd_obj.state_json_file_name + self.assertThat(state_file, FileExists()) + with codecs.open(state_file, encoding='utf-8', mode='r') as fd: + state = json.load(fd) + + # ensure the rollback was called in order + self.assertListEqual(state['rollback_test'], + ['never', 'gonna', 'give', 'you', 'up', + 'never', 'gonna', 'let', 'you', 'down'])