Refactor some JetStream helper code, add support for specifying JetStream domain

Signed-off-by: Neil Alexander <git@neilalexander.dev>
This commit is contained in:
Neil Alexander 2024-12-31 21:32:21 +00:00
parent add73ec866
commit a9824e78d0
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 132 additions and 112 deletions

View file

@ -15,6 +15,8 @@ type JetStream struct {
// The prefix to use for stream names for this homeserver - really only
// useful if running more than one Dendrite on the same NATS deployment.
TopicPrefix string `yaml:"topic_prefix"`
// The JetStream domain, if needed.
JetStreamDomain string `yaml:"js_domain"`
// Keep all storage in memory. This is mostly useful for unit tests.
InMemory bool `yaml:"in_memory"`
// Disable logging. This is mostly useful for unit tests.

View file

@ -22,7 +22,7 @@ func JetStreamConsumer(
f func(ctx context.Context, msgs []*nats.Msg) bool,
opts ...nats.SubOpt,
) error {
defer func() {
defer func(durable string) {
// If there are existing consumers from before they were pull
// consumers, we need to clean up the old push consumers. However,
// in order to not affect the interest-based policies, we need to
@ -33,23 +33,28 @@ func JetStreamConsumer(
logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable)
}
}
}()
}(durable)
name := durable + "Pull"
sub, err := js.PullSubscribe(subj, name, opts...)
durable = durable + "Pull"
sub, err := js.PullSubscribe(subj, durable, opts...)
if err != nil {
sentry.CaptureException(err)
return fmt.Errorf("nats.SubscribeSync: %w", err)
logrus.WithContext(ctx).WithError(err).Warnf("Failed to configure durable %q", durable)
return err
}
go func() {
go jetStreamConsumerWorker(ctx, sub, subj, batch, f)
return nil
}
func jetStreamConsumerWorker(
ctx context.Context, sub *nats.Subscription, subj string, batch int,
f func(ctx context.Context, msgs []*nats.Msg) bool,
) {
for {
// If the parent context has given up then there's no point in
// carrying on doing anything, so stop the listener.
select {
case <-ctx.Done():
if err := sub.Unsubscribe(); err != nil {
logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable)
}
return
default:
}
@ -73,19 +78,23 @@ func JetStreamConsumer(
// just timed out and we should try again.
continue
}
} else if errors.Is(err, nats.ErrConsumerDeleted) {
// The consumer was deleted so stop.
return
} else if errors.Is(err, nats.ErrTimeout) {
// Pull request was invalidated, try again.
continue
} else if errors.Is(err, nats.ErrConsumerLeadershipChanged) {
// Leadership changed so pending pull requests became invalidated,
// just try again.
continue
} else if err.Error() == "nats: Server Shutdown" {
// The server is shutting down, but we'll rely on reconnect
// behaviour to try and either connect us to another node (if
// clustered) or to reconnect when the server comes back up.
continue
} else {
// Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string
if err.Error() == "nats: Server Shutdown" {
logrus.WithContext(ctx).Warn("nats server shutting down")
// Something else went wrong.
logrus.WithContext(ctx).WithField("subject", subj).WithError(err).Warn("Error on pull subscriber fetch")
return
}
// Something else went wrong, so we'll panic.
sentry.CaptureException(err)
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
}
}
if len(msgs) < 1 {
continue
@ -113,6 +122,4 @@ func JetStreamConsumer(
}
}
}
}()
return nil
}

View file

@ -8,7 +8,6 @@ import (
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/element-hq/dendrite/setup/config"
@ -36,17 +35,20 @@ func DeleteAllStreams(js natsclient.JetStreamContext, cfg *config.JetStream) {
func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
natsLock.Lock()
defer natsLock.Unlock()
// check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 {
// reuse existing connections
if s.nc != nil {
var err error
// If an existing connection exists, return it.
if s.nc != nil && s.js != nil {
return s.js, s.nc
}
// For connecting to an external NATS server.
if len(cfg.Addresses) > 0 {
s.js, s.nc = setupNATS(process, cfg, nil)
return s.js, s.nc
}
if s.Server == nil {
var err error
if len(cfg.Addresses) == 0 && s.Server == nil {
opts := &natsserver.Options{
ServerName: "monolith",
DontListen: true,
@ -58,8 +60,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
NoLog: cfg.NoLog,
SyncAlways: true,
}
s.Server, err = natsserver.NewServer(opts)
if err != nil {
if s.Server, err = natsserver.NewServer(opts); err != nil {
panic(err)
}
if !cfg.NoLog {
@ -75,37 +76,47 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
s.WaitForShutdown()
process.ComponentFinished()
}()
}
if !s.ReadyForConnections(time.Second * 60) {
logrus.Fatalln("NATS did not start in time")
}
// reuse existing connections
if s.nc != nil {
return s.js, s.nc
}
nc, err := natsclient.Connect("", natsclient.InProcessServer(s))
if err != nil {
// No existing process connection, create a new one.
if s.nc, err = natsclient.Connect("", natsclient.InProcessServer(s.Server)); err != nil {
logrus.Fatalln("Failed to create NATS client")
}
js, _ := setupNATS(process, cfg, nc)
s.js = js
s.nc = nc
return js, nc
s.js, s.nc = setupNATS(process, cfg, s.nc)
return s.js, s.nc
}
// nolint:gocyclo
func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStreamContext, *natsclient.Conn) {
jsOpts := []natsclient.JSOpt{}
if cfg.JetStreamDomain != "" {
jsOpts = append(jsOpts, natsclient.Domain(cfg.JetStreamDomain))
}
if nc == nil {
var err error
opts := []natsclient.Option{}
opts := []natsclient.Option{
natsclient.Name("Dendrite"),
natsclient.MaxReconnects(-1), // Try forever
natsclient.ReconnectJitter(time.Second, time.Second),
natsclient.ReconnectWait(time.Second * 10),
natsclient.ReconnectHandler(func(c *natsclient.Conn) {
js, jerr := c.JetStream(jsOpts...)
if jerr != nil {
logrus.WithError(jerr).Panic("Unable to get JetStream context in reconnect handler")
return
}
checkAndConfigureStreams(process, cfg, js)
}),
}
if cfg.DisableTLSValidation {
opts = append(opts, natsclient.Secure(&tls.Config{
InsecureSkipVerify: true,
}))
}
if string(cfg.Credentials) != "" {
opts = append(opts, natsclient.UserCredentials(string(cfg.Credentials)))
}
nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","), opts...)
if err != nil {
logrus.WithError(err).Panic("Unable to connect to NATS")
@ -113,15 +124,19 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
}
}
s, err := nc.JetStream()
js, err := nc.JetStream(jsOpts...)
if err != nil {
logrus.WithError(err).Panic("Unable to get JetStream context")
return nil, nil
}
checkAndConfigureStreams(process, cfg, js)
return js, nc
}
func checkAndConfigureStreams(process *process.ProcessContext, cfg *config.JetStream, js natsclient.JetStreamContext) {
for _, stream := range streams { // streams are defined in streams.go
name := cfg.Prefixed(stream.Name)
info, err := s.StreamInfo(name)
info, err := js.StreamInfo(name)
if err != nil && err != natsclient.ErrStreamNotFound {
logrus.WithError(err).Fatal("Unable to get stream info")
}
@ -153,11 +168,11 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
case info.Config.MaxAge != stream.MaxAge:
// Try updating the stream first, as many things can be updated
// non-destructively.
if info, err = s.UpdateStream(stream); err != nil {
if info, err = js.UpdateStream(stream); err != nil {
logrus.WithError(err).Warnf("Unable to update stream %q, recreating...", name)
// We failed to update the stream, this is a last attempt to get
// things working but may result in data loss.
if err = s.DeleteStream(name); err != nil {
if err = js.DeleteStream(name); err != nil {
logrus.WithError(err).Fatalf("Unable to delete stream %q", name)
}
info = nil
@ -176,7 +191,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
namespaced := *stream
namespaced.Name = name
namespaced.Subjects = subjects
if _, err = s.AddStream(&namespaced); err != nil {
if _, err = js.AddStream(&namespaced); err != nil {
logger := logrus.WithError(err).WithFields(logrus.Fields{
"stream": namespaced.Name,
"subjects": namespaced.Subjects,
@ -193,10 +208,9 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
// we can't recover anything that was queued on the disk but we
// will still be able to start and run hopefully in the meantime.
logger.WithError(err).Error("Unable to add stream")
sentry.CaptureException(fmt.Errorf("Unable to add stream %q: %w", namespaced.Name, err))
namespaced.Storage = natsclient.MemoryStorage
if _, err = s.AddStream(&namespaced); err != nil {
if _, err = js.AddStream(&namespaced); err != nil {
// We tried to add the stream in-memory instead but something
// went wrong. That's an unrecoverable situation so we will
// give up at this point.
@ -208,7 +222,6 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
// disk will be left alone, but our ability to recover from a
// future crash will be limited. Yell about it.
err := fmt.Errorf("Stream %q is running in-memory; this may be due to data corruption in the JetStream storage directory", namespaced.Name)
sentry.CaptureException(err)
process.Degraded(err)
}
}
@ -229,15 +242,13 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers {
consumerName := cfg.Matrix.JetStream.Prefixed(consumer) + "Pull"
consumerInfo, err := s.ConsumerInfo(streamName, consumerName)
consumerInfo, err := js.ConsumerInfo(streamName, consumerName)
if err != nil || consumerInfo == nil {
continue
}
if err = s.DeleteConsumer(streamName, consumerName); err != nil {
if err = js.DeleteConsumer(streamName, consumerName); err != nil {
logrus.WithError(err).Errorf("Unable to clean up old consumer %q for stream %q", consumer, stream)
}
}
}
return s, nc
}