aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/leader/election/election.go
blob: 6f83298735f05af22fb0cf85769500d57406b13e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package election

import (
	"context"
	"errors"
	"log/slog"
	"time"

	"github.com/bsm/redislock"
)

// defaultLeaderElectionOptions holds the settings RunOrDie starts from before
// applying the LeaderElectionOption values passed by the caller.
var defaultLeaderElectionOptions = leaderElectionOptions{
	lockKey: "ibd-leader-election",
	lockTTL: 10 * time.Second,
}

// RunOrDie runs a Redis-backed leader election loop until ctx is canceled.
//
// Starting from defaultLeaderElectionOptions with opts applied on top, it
// repeatedly tries to obtain the election lock through client. Whenever the
// lock is obtained, onLeader is run via runLeader with a context that is
// canceled when leadership ends. Between attempts it waits lockTTL/5.
//
// RunOrDie blocks and returns only once ctx has been canceled.
func RunOrDie(
	ctx context.Context,
	client redislock.RedisClient,
	onLeader func(context.Context),
	opts ...LeaderElectionOption,
) {
	o := defaultLeaderElectionOptions
	for _, opt := range opts {
		opt(&o)
	}

	locker := redislock.New(client)

	// Election loop
	for {
		lock, err := locker.Obtain(ctx, o.lockKey, o.lockTTL, nil)
		switch {
		case err == nil:
			// We are the leader
			slog.DebugContext(ctx, "elected leader")
			runLeader(ctx, lock, onLeader, o)
		case errors.Is(err, redislock.ErrNotObtained):
			// Another instance is the leader
		case ctx.Err() != nil:
			// The attempt failed because we are shutting down; that is not a
			// lock failure worth reporting.
			return
		default:
			slog.ErrorContext(ctx, "failed to obtain lock", "error", err)
		}

		// Sleep for a bit before trying again
		timer := time.NewTimer(o.lockTTL / 5)
		select {
		case <-ctx.Done():
			// The timer is discarded right away, so there is no need to drain
			// its channel.
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// runLeader holds leadership for one term. It starts onLeader in its own
// goroutine and refreshes lock every o.lockTTL/10 until onLeader returns, ctx
// is canceled, or a refresh reports that the lock is no longer held.
//
// The context handed to onLeader is canceled as soon as the term ends.
// runLeader then waits for onLeader to return before it releases the lock and
// returns, so the lock is not given up (and the caller cannot start another
// term) while leader code is still running. onLeader must therefore return
// promptly once its context is canceled.
func runLeader(
	ctx context.Context,
	lock *redislock.Lock,
	onLeader func(context.Context),
	o leaderElectionOptions,
) {
	// A context that is canceled when the leader loses the lock
	leaderCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Run the leader code; done is closed once it has returned
	done := make(chan struct{})
	go func() {
		defer close(done)
		onLeader(leaderCtx)
	}()

	// On the way out: stop the leader code, wait for it, then release the lock
	defer func() {
		cancel()
		<-done

		// Always detach from cancellation so the release goes through during
		// shutdown as well; checking ctx.Err() first would race with a
		// cancellation arriving just after the check. The timeout keeps the
		// release from blocking forever.
		relCtx, relCancel := context.WithTimeout(context.WithoutCancel(ctx), o.lockTTL)
		defer relCancel()

		if err := lock.Release(relCtx); err != nil {
			slog.ErrorContext(relCtx, "failed to release lock", "error", err)
		}
	}()

	// Refresh the lock periodically
	ticker := time.NewTicker(o.lockTTL / 10)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			// The leader code returned on its own; give up leadership
			return
		case <-leaderCtx.Done():
			return
		case <-ticker.C:
			err := lock.Refresh(leaderCtx, o.lockTTL, nil)
			if errors.Is(err, redislock.ErrNotObtained) || errors.Is(err, redislock.ErrLockNotHeld) {
				slog.ErrorContext(leaderCtx, "leadership lost", "error", err)
				return
			} else if err != nil {
				// Possibly transient (e.g. a network error): keep the term and
				// try again on the next tick.
				slog.ErrorContext(leaderCtx, "failed to refresh lock", "error", err)
			}
		}
	}
}

// leaderElectionOptions is the configuration used by RunOrDie and runLeader.
type leaderElectionOptions struct {
	// lockKey is the key of the lock that instances compete for.
	lockKey string
	// lockTTL is the TTL the lock is obtained and refreshed with. The retry
	// interval (lockTTL/5) and refresh interval (lockTTL/10) are derived from it.
	lockTTL time.Duration
}

// LeaderElectionOption customizes the election settings used by RunOrDie. Each
// option is applied, in order, to a copy of the defaults.
type LeaderElectionOption func(*leaderElectionOptions)

// WithLockKey returns an option that replaces the key of the election lock
// with key.
func WithLockKey(key string) LeaderElectionOption {
	setKey := func(opts *leaderElectionOptions) {
		opts.lockKey = key
	}
	return setKey
}

// WithLockTTL returns an option that replaces the TTL of the election lock
// with ttl. The retry and refresh intervals are derived from this value.
func WithLockTTL(ttl time.Duration) LeaderElectionOption {
	setTTL := func(opts *leaderElectionOptions) {
		opts.lockTTL = ttl
	}
	return setTTL
}