Move rollback into NodeBase object
Currently we pass a reference to a global "rollback" list to create() to keep rollback functions. Other nodes don't need to know about global rollback state, and by passing by reference we're giving them the chance to mess it up for everyone else. Add a "add_rollback()" function in NodeBase for create() calls to register rollback calls within themselves. As they hit rollback points they can add a new entry. lambda v arguments is much of a muchness -- but this is similar to the standard atexit() call so with go with that pattern. A new "rollback()" call is added that the driver will invoke on each node as it works its way backwards in case of failure. On error, nodes will have rollback() called in reverse order (which then calls registered rollbacks in reverse order). A unit test is added to test rollback behaviour Change-Id: I65214e72c7ef607dd08f750a6d32a0b10fe97ac3
This commit is contained in:
parent
09dee46579
commit
1d1e4ccb3e
13 changed files with 161 additions and 43 deletions
|
@ -20,7 +20,6 @@ import os
|
|||
import pickle
|
||||
import pprint
|
||||
import shutil
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
from diskimage_builder.block_device.config import config_tree_to_graph
|
||||
|
@ -370,18 +369,21 @@ class BlockDevice(object):
|
|||
logger.info("create() called")
|
||||
logger.debug("Using config [%s]", self.config)
|
||||
|
||||
rollback = []
|
||||
# Create a new, empty state
|
||||
state = BlockDeviceState()
|
||||
try:
|
||||
dg, call_order = create_graph(self.config, self.params, state)
|
||||
for node in call_order:
|
||||
node.create(rollback)
|
||||
node.create()
|
||||
except Exception:
|
||||
logger.exception("Create failed; rollback initiated")
|
||||
for rollback_cb in reversed(rollback):
|
||||
rollback_cb()
|
||||
sys.exit(1)
|
||||
reverse_order = reversed(call_order)
|
||||
for node in reverse_order:
|
||||
node.rollback()
|
||||
# save the state for debugging
|
||||
state.save_state(self.state_json_file_name)
|
||||
logger.error("Rollback complete, exiting")
|
||||
raise
|
||||
|
||||
# dump state and nodes, in order
|
||||
# XXX: we only dump the call_order (i.e. nodes) not the whole
|
||||
|
|
|
@ -100,15 +100,15 @@ class LocalLoopNode(NodeBase):
|
|||
"""Because this is created without base, there are no edges."""
|
||||
return ([], [])
|
||||
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
logger.debug("[%s] Creating loop on [%s] with size [%d]",
|
||||
self.name, self.filename, self.size)
|
||||
|
||||
rollback.append(lambda: image_delete(self.filename))
|
||||
self.add_rollback(image_delete, self.filename)
|
||||
image_create(self.filename, self.size)
|
||||
|
||||
block_device = loopdev_attach(self.filename)
|
||||
rollback.append(lambda: loopdev_detach(block_device))
|
||||
self.add_rollback(loopdev_detach, block_device)
|
||||
|
||||
if 'blockdev' not in self.state:
|
||||
self.state['blockdev'] = {}
|
||||
|
|
|
@ -65,5 +65,5 @@ class PartitionNode(NodeBase):
|
|||
edge_from.append(self.prev_partition.name)
|
||||
return (edge_from, edge_to)
|
||||
|
||||
def create(self, rollback):
|
||||
self.partitioning.create(rollback)
|
||||
def create(self):
|
||||
self.partitioning.create()
|
||||
|
|
|
@ -132,7 +132,7 @@ class Partitioning(PluginBase):
|
|||
|
||||
exec_sudo(["kpartx", "-avs", device_path])
|
||||
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
# not this is NOT a node and this is not called directly! The
|
||||
# create() calls in the partition nodes this plugin has
|
||||
# created are calling back into this.
|
||||
|
|
|
@ -103,7 +103,7 @@ class FilesystemNode(NodeBase):
|
|||
edge_to = []
|
||||
return (edge_from, edge_to)
|
||||
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
cmd = ["mkfs"]
|
||||
|
||||
cmd.extend(['-t', self.type])
|
||||
|
|
|
@ -71,7 +71,7 @@ class MountPointNode(NodeBase):
|
|||
edge_from.append(self.base)
|
||||
return (edge_from, edge_to)
|
||||
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
logger.debug("mount called [%s]", self.mount_point)
|
||||
rel_mp = self.mount_point if self.mount_point[0] != '/' \
|
||||
else self.mount_point[1:]
|
||||
|
|
|
@ -34,7 +34,7 @@ class FstabNode(NodeBase):
|
|||
edge_to = []
|
||||
return (edge_from, edge_to)
|
||||
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
logger.debug("fstab create called [%s]", self.name)
|
||||
|
||||
if 'fstab' not in self.state:
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# under the License.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import six
|
||||
|
||||
#
|
||||
|
@ -18,6 +19,8 @@ import six
|
|||
# processing. This defines the abstract classes for both.
|
||||
#
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class NodeBase(object):
|
||||
|
@ -44,12 +47,42 @@ class NodeBase(object):
|
|||
def __init__(self, name, state):
|
||||
self.name = name
|
||||
self.state = state
|
||||
self.rollbacks = []
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
||||
def add_rollback(self, func, *args, **kwargs):
|
||||
"""Add a call for rollback
|
||||
|
||||
Functions registered with this method will be called in
|
||||
reverse-order in the case of failures during
|
||||
:func:`Nodebase.create`.
|
||||
|
||||
:param func: function to call
|
||||
:param args: arguments
|
||||
:param kwargs: keyword arguments
|
||||
:return: None
|
||||
"""
|
||||
self.rollbacks.append((func, args, kwargs))
|
||||
|
||||
def rollback(self):
|
||||
"""Initiate rollback
|
||||
|
||||
Call registered rollback in reverse order. This method is
|
||||
called by the driver in the case of failures during
|
||||
:func:`Nodebase.create`.
|
||||
|
||||
:return None:
|
||||
"""
|
||||
# XXX: maybe ignore SystemExit so we always continue?
|
||||
logger.debug("Calling rollback for %s", self.name)
|
||||
for func, args, kwargs in reversed(self.rollbacks):
|
||||
func(*args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_edges(self):
|
||||
|
||||
"""Return the dependencies/edges for this node
|
||||
|
||||
This function will be called after all nodes are created (this
|
||||
|
@ -75,22 +108,14 @@ class NodeBase(object):
|
|||
return
|
||||
|
||||
@abc.abstractmethod
|
||||
def create(self, rollback):
|
||||
def create(self):
|
||||
"""Main creation driver
|
||||
|
||||
This is the main driver function. After the graph is
|
||||
linearised, each node has it's :func:`create` function called.
|
||||
|
||||
Arguments:
|
||||
|
||||
:param rollback: A shared list of functions to be called in
|
||||
the failure case. Nodes should only append to this list.
|
||||
On failure, the callbacks will be processed in reverse
|
||||
order.
|
||||
|
||||
:raises Exception: A failure should raise an exception. This
|
||||
will initiate the rollback
|
||||
|
||||
will initiate a rollback. See :func:`Nodebase.add_rollback`.
|
||||
:return: None
|
||||
"""
|
||||
return
|
||||
|
|
29
diskimage_builder/block_device/tests/config/rollback.yaml
Normal file
29
diskimage_builder/block_device/tests/config/rollback.yaml
Normal file
|
@ -0,0 +1,29 @@
|
|||
- test_a:
|
||||
name: test_node_a
|
||||
rollback_one_arg: down
|
||||
rollback_two_arg: you
|
||||
|
||||
- test_b:
|
||||
base: test_node_a
|
||||
name: test_node_b
|
||||
rollback_one_arg: let
|
||||
rollback_two_arg: gonna
|
||||
|
||||
- test_a:
|
||||
base: test_node_b
|
||||
name: test_node_aa
|
||||
rollback_one_arg: never
|
||||
rollback_two_arg: up
|
||||
|
||||
- test_b:
|
||||
base: test_node_aa
|
||||
name: test_node_bb
|
||||
rollback_one_arg: you
|
||||
rollback_two_arg: give
|
||||
|
||||
- test_a:
|
||||
base: test_node_bb
|
||||
name: test_node_aaa
|
||||
rollback_one_arg: gonna
|
||||
rollback_two_arg: never
|
||||
trigger_rollback: yes
|
|
@ -22,22 +22,46 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TestANode(NodeBase):
|
||||
def __init__(self, name, state):
|
||||
def __init__(self, config, state, test_rollback):
|
||||
logger.debug("Create test 1")
|
||||
super(TestANode, self).__init__(name, state)
|
||||
super(TestANode, self).__init__(config['name'], state)
|
||||
# might be a root node, so possibly no base
|
||||
if 'base' in config:
|
||||
self.base = config['base']
|
||||
|
||||
# put something in the state for test_b to check for
|
||||
state['test_init_state'] = 'here'
|
||||
|
||||
def get_edges(self):
|
||||
# this is like the loop node; it's a root and doesn't have a
|
||||
# base
|
||||
return ([], [])
|
||||
# If we're doing rollback testing the config has some strings
|
||||
# set for us
|
||||
if test_rollback:
|
||||
self.add_rollback(self.do_rollback, config['rollback_one_arg'])
|
||||
self.add_rollback(self.do_rollback, config['rollback_two_arg'])
|
||||
# see if we're the node who is going to fail
|
||||
self.trigger_rollback = True if 'trigger_rollback' in config else False
|
||||
|
||||
def create(self, rollback):
|
||||
def get_edges(self):
|
||||
# may not have a base, if used as root node
|
||||
to = [self.base] if hasattr(self, 'base') else []
|
||||
return (to, [])
|
||||
|
||||
def do_rollback(self, string):
|
||||
# We will check this after all rollbacks to make sure they ran
|
||||
# in the right order
|
||||
self.state['rollback_test'].append(string)
|
||||
|
||||
def create(self):
|
||||
# put some fake entries into state
|
||||
self.state['test_a'] = {}
|
||||
self.state['test_a']['value'] = 'foo'
|
||||
self.state['test_a']['value2'] = 'bar'
|
||||
|
||||
if self.trigger_rollback:
|
||||
# The rollback test will append the strings to this as
|
||||
# it unrolls, and we'll check it's value at the end
|
||||
self.state['rollback_test'] = []
|
||||
raise RuntimeError("Rollback triggered")
|
||||
|
||||
return
|
||||
|
||||
def umount(self):
|
||||
|
@ -49,7 +73,9 @@ class TestA(PluginBase):
|
|||
|
||||
def __init__(self, config, defaults, state):
|
||||
super(TestA, self).__init__()
|
||||
self.node = TestANode(config['name'], state)
|
||||
|
||||
test_rollback = True if 'test_rollback' in defaults else False
|
||||
self.node = TestANode(config, state, test_rollback)
|
||||
|
||||
def get_nodes(self):
|
||||
return [self.node]
|
||||
|
|
|
@ -22,10 +22,16 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TestBNode(NodeBase):
|
||||
def __init__(self, name, state, base):
|
||||
def __init__(self, config, state, test_rollback):
|
||||
logger.debug("Create test 1")
|
||||
super(TestBNode, self).__init__(name, state)
|
||||
self.base = base
|
||||
super(TestBNode, self).__init__(config['name'], state)
|
||||
self.base = config['base']
|
||||
|
||||
# If we're doing rollback testing the config has some strings
|
||||
# set for us.
|
||||
if test_rollback:
|
||||
self.add_rollback(self.do_rollback, config['rollback_one_arg'])
|
||||
self.add_rollback(self.do_rollback, config['rollback_two_arg'])
|
||||
|
||||
def get_edges(self):
|
||||
# this should have been inserted by test_a before
|
||||
|
@ -33,7 +39,12 @@ class TestBNode(NodeBase):
|
|||
assert self.state['test_init_state'] == 'here'
|
||||
return ([self.base], [])
|
||||
|
||||
def create(self, rollback):
|
||||
def do_rollback(self, string):
|
||||
# We will check this after all rollbacks to make sure they ran
|
||||
# in the right order
|
||||
self.state['rollback_test'].append(string)
|
||||
|
||||
def create(self):
|
||||
self.state['test_b'] = {}
|
||||
self.state['test_b']['value'] = 'baz'
|
||||
return
|
||||
|
@ -52,9 +63,9 @@ class TestB(PluginBase):
|
|||
|
||||
def __init__(self, config, defaults, state):
|
||||
super(TestB, self).__init__()
|
||||
self.node = TestBNode(config['name'],
|
||||
state,
|
||||
config['base'])
|
||||
|
||||
test_rollback = True if 'test_rollback' in defaults else False
|
||||
self.node = TestBNode(config, state, test_rollback)
|
||||
|
||||
def get_nodes(self):
|
||||
return [self.node]
|
||||
|
|
|
@ -33,8 +33,6 @@ class TestMountOrder(tc.TestGraphGeneration):
|
|||
graph, call_order = create_graph(config, self.fake_default_config,
|
||||
state)
|
||||
|
||||
rollback = []
|
||||
|
||||
# build up some fake state so that we don't have to mock out
|
||||
# all the parent calls that would really make these values, as
|
||||
# we just want to test MountPointNode
|
||||
|
@ -51,7 +49,7 @@ class TestMountOrder(tc.TestGraphGeneration):
|
|||
# XXX: do we even need to create? We could test the
|
||||
# sudo arguments from the mock in the below asserts
|
||||
# too
|
||||
node.create(rollback)
|
||||
node.create()
|
||||
|
||||
# ensure that partitions are mounted in order root->var->var/log
|
||||
self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log'])
|
||||
|
|
|
@ -123,3 +123,30 @@ class TestState(TestStateBase):
|
|||
bd_obj.cmd_cleanup)
|
||||
|
||||
# XXX: figure out unit test for umount
|
||||
|
||||
# Test ordering of rollback calls if create() fails
|
||||
def test_rollback(self):
|
||||
params = {
|
||||
'build-dir': self.build_dir.path,
|
||||
'config': self.get_config_file('rollback.yaml'),
|
||||
'test_rollback': True
|
||||
}
|
||||
|
||||
bd_obj = bd.BlockDevice(params)
|
||||
bd_obj.cmd_init()
|
||||
|
||||
# The config file has flags in that tell the last node to
|
||||
# fail, which will trigger the rollback.
|
||||
self.assertRaises(RuntimeError, bd_obj.cmd_create)
|
||||
|
||||
# cmd_create should have persisted this to disk even after the
|
||||
# failure
|
||||
state_file = bd_obj.state_json_file_name
|
||||
self.assertThat(state_file, FileExists())
|
||||
with codecs.open(state_file, encoding='utf-8', mode='r') as fd:
|
||||
state = json.load(fd)
|
||||
|
||||
# ensure the rollback was called in order
|
||||
self.assertListEqual(state['rollback_test'],
|
||||
['never', 'gonna', 'give', 'you', 'up',
|
||||
'never', 'gonna', 'let', 'you', 'down'])
|
||||
|
|
Loading…
Reference in a new issue