Move rollback into NodeBase object

Currently we pass a reference to a global "rollback" list to create()
to keep rollback functions.  Other nodes don't need to know about
global rollback state, and by passing by reference we're giving them
the chance to mess it up for everyone else.

Add a "add_rollback()" function in NodeBase for create() calls to
register rollback calls within themselves.  As they hit rollback
points they can add a new entry.  lambda v arguments is much of a
muchness -- but this is similar to the standard atexit() call so with
go with that pattern.  A new "rollback()" call is added that the
driver will invoke on each node as it works its way backwards in case
of failure.

On error, nodes will have rollback() called in reverse order (which
then calls registered rollbacks in reverse order).

A unit test is added to test rollback behaviour

Change-Id: I65214e72c7ef607dd08f750a6d32a0b10fe97ac3
This commit is contained in:
Ian Wienand 2017-06-02 10:57:06 +10:00
parent 09dee46579
commit 1d1e4ccb3e
13 changed files with 161 additions and 43 deletions

View file

@ -20,7 +20,6 @@ import os
import pickle
import pprint
import shutil
import sys
import yaml
from diskimage_builder.block_device.config import config_tree_to_graph
@ -370,18 +369,21 @@ class BlockDevice(object):
logger.info("create() called")
logger.debug("Using config [%s]", self.config)
rollback = []
# Create a new, empty state
state = BlockDeviceState()
try:
dg, call_order = create_graph(self.config, self.params, state)
for node in call_order:
node.create(rollback)
node.create()
except Exception:
logger.exception("Create failed; rollback initiated")
for rollback_cb in reversed(rollback):
rollback_cb()
sys.exit(1)
reverse_order = reversed(call_order)
for node in reverse_order:
node.rollback()
# save the state for debugging
state.save_state(self.state_json_file_name)
logger.error("Rollback complete, exiting")
raise
# dump state and nodes, in order
# XXX: we only dump the call_order (i.e. nodes) not the whole

View file

@ -100,15 +100,15 @@ class LocalLoopNode(NodeBase):
"""Because this is created without base, there are no edges."""
return ([], [])
def create(self, rollback):
def create(self):
logger.debug("[%s] Creating loop on [%s] with size [%d]",
self.name, self.filename, self.size)
rollback.append(lambda: image_delete(self.filename))
self.add_rollback(image_delete, self.filename)
image_create(self.filename, self.size)
block_device = loopdev_attach(self.filename)
rollback.append(lambda: loopdev_detach(block_device))
self.add_rollback(loopdev_detach, block_device)
if 'blockdev' not in self.state:
self.state['blockdev'] = {}

View file

@ -65,5 +65,5 @@ class PartitionNode(NodeBase):
edge_from.append(self.prev_partition.name)
return (edge_from, edge_to)
def create(self, rollback):
self.partitioning.create(rollback)
def create(self):
self.partitioning.create()

View file

@ -132,7 +132,7 @@ class Partitioning(PluginBase):
exec_sudo(["kpartx", "-avs", device_path])
def create(self, rollback):
def create(self):
# not this is NOT a node and this is not called directly! The
# create() calls in the partition nodes this plugin has
# created are calling back into this.

View file

@ -103,7 +103,7 @@ class FilesystemNode(NodeBase):
edge_to = []
return (edge_from, edge_to)
def create(self, rollback):
def create(self):
cmd = ["mkfs"]
cmd.extend(['-t', self.type])

View file

@ -71,7 +71,7 @@ class MountPointNode(NodeBase):
edge_from.append(self.base)
return (edge_from, edge_to)
def create(self, rollback):
def create(self):
logger.debug("mount called [%s]", self.mount_point)
rel_mp = self.mount_point if self.mount_point[0] != '/' \
else self.mount_point[1:]

View file

@ -34,7 +34,7 @@ class FstabNode(NodeBase):
edge_to = []
return (edge_from, edge_to)
def create(self, rollback):
def create(self):
logger.debug("fstab create called [%s]", self.name)
if 'fstab' not in self.state:

View file

@ -11,6 +11,7 @@
# under the License.
import abc
import logging
import six
#
@ -18,6 +19,8 @@ import six
# processing. This defines the abstract classes for both.
#
logger = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class NodeBase(object):
@ -44,12 +47,42 @@ class NodeBase(object):
def __init__(self, name, state):
self.name = name
self.state = state
self.rollbacks = []
def get_name(self):
return self.name
def add_rollback(self, func, *args, **kwargs):
"""Add a call for rollback
Functions registered with this method will be called in
reverse-order in the case of failures during
:func:`Nodebase.create`.
:param func: function to call
:param args: arguments
:param kwargs: keyword arguments
:return: None
"""
self.rollbacks.append((func, args, kwargs))
def rollback(self):
"""Initiate rollback
Call registered rollback in reverse order. This method is
called by the driver in the case of failures during
:func:`Nodebase.create`.
:return None:
"""
# XXX: maybe ignore SystemExit so we always continue?
logger.debug("Calling rollback for %s", self.name)
for func, args, kwargs in reversed(self.rollbacks):
func(*args, **kwargs)
@abc.abstractmethod
def get_edges(self):
"""Return the dependencies/edges for this node
This function will be called after all nodes are created (this
@ -75,22 +108,14 @@ class NodeBase(object):
return
@abc.abstractmethod
def create(self, rollback):
def create(self):
"""Main creation driver
This is the main driver function. After the graph is
linearised, each node has it's :func:`create` function called.
Arguments:
:param rollback: A shared list of functions to be called in
the failure case. Nodes should only append to this list.
On failure, the callbacks will be processed in reverse
order.
:raises Exception: A failure should raise an exception. This
will initiate the rollback
will initiate a rollback. See :func:`Nodebase.add_rollback`.
:return: None
"""
return

View file

@ -0,0 +1,29 @@
- test_a:
name: test_node_a
rollback_one_arg: down
rollback_two_arg: you
- test_b:
base: test_node_a
name: test_node_b
rollback_one_arg: let
rollback_two_arg: gonna
- test_a:
base: test_node_b
name: test_node_aa
rollback_one_arg: never
rollback_two_arg: up
- test_b:
base: test_node_aa
name: test_node_bb
rollback_one_arg: you
rollback_two_arg: give
- test_a:
base: test_node_bb
name: test_node_aaa
rollback_one_arg: gonna
rollback_two_arg: never
trigger_rollback: yes

View file

@ -22,22 +22,46 @@ logger = logging.getLogger(__name__)
class TestANode(NodeBase):
def __init__(self, name, state):
def __init__(self, config, state, test_rollback):
logger.debug("Create test 1")
super(TestANode, self).__init__(name, state)
super(TestANode, self).__init__(config['name'], state)
# might be a root node, so possibly no base
if 'base' in config:
self.base = config['base']
# put something in the state for test_b to check for
state['test_init_state'] = 'here'
def get_edges(self):
# this is like the loop node; it's a root and doesn't have a
# base
return ([], [])
# If we're doing rollback testing the config has some strings
# set for us
if test_rollback:
self.add_rollback(self.do_rollback, config['rollback_one_arg'])
self.add_rollback(self.do_rollback, config['rollback_two_arg'])
# see if we're the node who is going to fail
self.trigger_rollback = True if 'trigger_rollback' in config else False
def create(self, rollback):
def get_edges(self):
# may not have a base, if used as root node
to = [self.base] if hasattr(self, 'base') else []
return (to, [])
def do_rollback(self, string):
# We will check this after all rollbacks to make sure they ran
# in the right order
self.state['rollback_test'].append(string)
def create(self):
# put some fake entries into state
self.state['test_a'] = {}
self.state['test_a']['value'] = 'foo'
self.state['test_a']['value2'] = 'bar'
if self.trigger_rollback:
# The rollback test will append the strings to this as
# it unrolls, and we'll check it's value at the end
self.state['rollback_test'] = []
raise RuntimeError("Rollback triggered")
return
def umount(self):
@ -49,7 +73,9 @@ class TestA(PluginBase):
def __init__(self, config, defaults, state):
super(TestA, self).__init__()
self.node = TestANode(config['name'], state)
test_rollback = True if 'test_rollback' in defaults else False
self.node = TestANode(config, state, test_rollback)
def get_nodes(self):
return [self.node]

View file

@ -22,10 +22,16 @@ logger = logging.getLogger(__name__)
class TestBNode(NodeBase):
def __init__(self, name, state, base):
def __init__(self, config, state, test_rollback):
logger.debug("Create test 1")
super(TestBNode, self).__init__(name, state)
self.base = base
super(TestBNode, self).__init__(config['name'], state)
self.base = config['base']
# If we're doing rollback testing the config has some strings
# set for us.
if test_rollback:
self.add_rollback(self.do_rollback, config['rollback_one_arg'])
self.add_rollback(self.do_rollback, config['rollback_two_arg'])
def get_edges(self):
# this should have been inserted by test_a before
@ -33,7 +39,12 @@ class TestBNode(NodeBase):
assert self.state['test_init_state'] == 'here'
return ([self.base], [])
def create(self, rollback):
def do_rollback(self, string):
# We will check this after all rollbacks to make sure they ran
# in the right order
self.state['rollback_test'].append(string)
def create(self):
self.state['test_b'] = {}
self.state['test_b']['value'] = 'baz'
return
@ -52,9 +63,9 @@ class TestB(PluginBase):
def __init__(self, config, defaults, state):
super(TestB, self).__init__()
self.node = TestBNode(config['name'],
state,
config['base'])
test_rollback = True if 'test_rollback' in defaults else False
self.node = TestBNode(config, state, test_rollback)
def get_nodes(self):
return [self.node]

View file

@ -33,8 +33,6 @@ class TestMountOrder(tc.TestGraphGeneration):
graph, call_order = create_graph(config, self.fake_default_config,
state)
rollback = []
# build up some fake state so that we don't have to mock out
# all the parent calls that would really make these values, as
# we just want to test MountPointNode
@ -51,7 +49,7 @@ class TestMountOrder(tc.TestGraphGeneration):
# XXX: do we even need to create? We could test the
# sudo arguments from the mock in the below asserts
# too
node.create(rollback)
node.create()
# ensure that partitions are mounted in order root->var->var/log
self.assertListEqual(state['mount_order'], ['/', '/var', '/var/log'])

View file

@ -123,3 +123,30 @@ class TestState(TestStateBase):
bd_obj.cmd_cleanup)
# XXX: figure out unit test for umount
# Test ordering of rollback calls if create() fails
def test_rollback(self):
params = {
'build-dir': self.build_dir.path,
'config': self.get_config_file('rollback.yaml'),
'test_rollback': True
}
bd_obj = bd.BlockDevice(params)
bd_obj.cmd_init()
# The config file has flags in that tell the last node to
# fail, which will trigger the rollback.
self.assertRaises(RuntimeError, bd_obj.cmd_create)
# cmd_create should have persisted this to disk even after the
# failure
state_file = bd_obj.state_json_file_name
self.assertThat(state_file, FileExists())
with codecs.open(state_file, encoding='utf-8', mode='r') as fd:
state = json.load(fd)
# ensure the rollback was called in order
self.assertListEqual(state['rollback_test'],
['never', 'gonna', 'give', 'you', 'up',
'never', 'gonna', 'let', 'you', 'down'])