diff --git a/conn_test.go b/conn_test.go index 6c89a6a..b5f5a63 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,187 +1,124 @@ package journal import ( - "bytes" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "net" "path/filepath" - "strconv" "testing" + + "src.lwithers.me.uk/go/journal/testsink" ) -type testingCommon interface { - TempDir() string - Fatalf(string, ...any) -} - -// testConnector spawns a Conn with a local Unix datagram socket that checks -// incoming datagrams for well-formedness but otherwise discards them. -func testConnector(t testingCommon) *Conn { - sockPath := filepath.Join(t.TempDir(), "test-socket") - sock, err := net.ListenUnixgram("unixgram", &net.UnixAddr{ - Name: sockPath, - Net: "unixgram", - }) +// TestConn opens a connection to a test sink and writes one message, +// ensuring it is received. +func TestConn(t *testing.T) { + sockpath := filepath.Join(t.TempDir(), "socket") + sink, err := testsink.New(sockpath) if err != nil { - t.Fatalf("testConnector: ListenUnix: %v", err) - } - - go func() { - buf := make([]byte, 1<<16 /*enough for our tests */) - for { - n, err := sock.Read(buf) - switch err { - case nil: - case io.EOF: - return - default: - t.Fatalf("testConnector: Read: %v", err) - } - ok, pos, desc := checkWellFormedProto(buf[:n]) - if !ok { - t.Fatalf("received malformed data at pos 0x%x: %s\n%s", pos, desc, hex.Dump(buf[:n])) - } - } - }() - - conn, err := Connect(sockPath) - if err != nil { - t.Fatalf("testConnector: DialUnix: %v", err) - } - conn.ErrHandler = func(err error) { - t.Fatalf("testConnector: connection error: %v", err) - } - return conn -} - -func checkWellFormedProto(buf []byte) (ok bool, pos int, desc string) { - for pos < len(buf) { - // grab attribute name up to '=' or '\n' - off := bytes.IndexAny(buf[pos:], "=\n") - if off == -1 { - return false, pos, "unterminated key" - } - key := string(buf[pos : pos+off]) - if err := AttrKeyValid(key); err != nil { - return false, pos, err.Error() - } - pos += off - - // for KEY=VALUE, the value is terminated by a newline - if buf[pos] == '=' { - pos++ - off = bytes.IndexByte(buf[pos:], '\n') - if off == -1 { - return false, pos, "unterminated value" - } - pos += off // consume value - pos++ // consume trailing newline - continue - } - - // otherwise, expect an 8-bit little-endian length - pos++ // consume newline after key - if pos+8 > len(buf) { - return false, pos, "value length too short" - } - vlen := binary.LittleEndian.Uint64(buf[pos:]) - pos += 8 - if vlen > uint64(len(buf)) /* protect against overflow */ || - uint64(pos)+vlen+1 > uint64(len(buf)) { - return false, pos, "value length too long" - } - pos += int(vlen) - if buf[pos] != '\n' { - return false, pos, "value not terminated by newline" - } - pos++ - } - return true, pos, "" -} - -// TestConcurrentEntries is best run with the race detector, and tries to pick -// up any faults that might occur when concurrent goroutines write into the same -// Conn. -func TestConcurrentEntries(t *testing.T) { - c := testConnector(t) - const ( - numGoroutines = 16 - numEntries = 100 - ) - - // attributes which will be common to all EntryErr calls - attr := make([]Attr, 0, 10 /* enough capacity to avoid realloc on append; might trigger data races */) - attr = append(attr, Attr{ - Key: AttrKey{ - key: "HELLO", - }, - Value: []byte("world"), - }) - - // spawn goroutines - start := make(chan struct{}) - result := make(chan error, numGoroutines) - for i := range numGoroutines { - go func() { - var err error - <-start - for j := range numEntries { - err = c.EntryErr(PriDebug, "message "+strconv.Itoa(i)+"."+strconv.Itoa(j), attr) - if err != nil { - err = fmt.Errorf("message %d.%d error: %w", i, j, err) - break - } - } - result <- err - }() - } - - // try to get all the goroutines to start at roughly the same time - close(start) - - // collect results - var errs []error - for range numGoroutines { - if err := <-result; err != nil { - errs = append(errs, err) - } - } - if err := errors.Join(errs...); err != nil { t.Fatal(err) } -} + defer sink.Stop() -// BenchmarkEntry is a benchmark for the common Entry function. -func BenchmarkEntry(b *testing.B) { - c := testConnector(b) + conn, err := Connect(sockpath) + if err != nil { + t.Fatal(err) + } + defer conn.Close() - // add some common attributes - c.Common = make([]Attr, 0, 10) - c.Common = append(c.Common, Attr{ - Key: AttrKey{ - key: "COMMON_ATTR", - }, - Value: []byte("abc123\n"), - }) + conn.Entry(PriInfo, "hello, world", nil) - // attributes which will be common to all EntryErr calls - attr := make([]Attr, 0, 10) - attr = append(attr, Attr{ - Key: AttrKey{ - key: "HELLO", - }, - Value: []byte("world"), - }) + msg, err := sink.Message(0) + if err != nil { + t.Fatal(err) + } - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := c.EntryErr(PriDebug, "hello world!", attr) - if err != nil { - b.Fatalf("message %d: error %v", i, err) + msgText, attrs, err := msg.Decode() + if err != nil { + t.Error(err) + } + if msgText != "hello, world" { + t.Errorf("unexpected message text %q", msgText) + } + val, ok := testsink.GetAttr(attrs, "PRIORITY") + switch { + case !ok: + t.Error("missing PRIORITY attribute") + case val != "6": + t.Error("unexpected PRIORITY value") + } + + if t.Failed() { + for i := range attrs { + t.Errorf("attr %q=%q", attrs[i].Key, attrs[i].Val) + } + } +} + +// TestEntryBinary ensures that we can write a message with an attribute encoded +// as a binary field. +func TestEntryBinary(t *testing.T) { + sockpath := filepath.Join(t.TempDir(), "socket") + sink, err := testsink.New(sockpath) + if err != nil { + t.Fatal(err) + } + defer sink.Stop() + + conn, err := Connect(sockpath) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + expAttrs := []Attr{ + Attr{ + Key: MustAttrKey("SHORT"), + Value: []byte("short val"), + }, + Attr{ + Key: MustAttrKey("BINARY"), + Value: []byte("string with\n=embedded newline\nrequires binary protocol\n"), + }, + Attr{ + Key: MustAttrKey("LAST"), + Value: []byte("last\n"), + }, + } + conn.Entry(PriDebug, "hello, binary world", expAttrs) + + msg, err := sink.Message(0) + if err != nil { + t.Fatal(err) + } + + msgText, attrs, err := msg.Decode() + if err != nil { + t.Error(err) + } + if msgText != "hello, binary world" { + t.Errorf("unexpected message text %q", msgText) + } + + val, ok := testsink.GetAttr(attrs, "PRIORITY") + switch { + case !ok: + t.Error("missing PRIORITY attribute") + case val != "7": + t.Error("unexpected PRIORITY value") + } + + for i := range expAttrs { + key, expVal := expAttrs[i].Key.Key(), string(expAttrs[i].Value) + val, ok = testsink.GetAttr(attrs, key) + switch { + case !ok: + t.Errorf("missing %s attribute", key) + case val != expVal: + t.Errorf("unexpected %s value", key) + } + } + + if t.Failed() { + for i := range attrs { + t.Errorf("attr %q=%q", attrs[i].Key, attrs[i].Val) } } } diff --git a/testsink/message.go b/testsink/message.go new file mode 100644 index 0000000..99cae43 --- /dev/null +++ b/testsink/message.go @@ -0,0 +1,126 @@ +/* +Package testsink provides a partial implementation of the systemd-journald +Unix socket. Datagrams received on its socket are decoded and stored for unit +tests to examine. +*/ +package testsink + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" +) + +// Message returns the Nth message received. It waits for the message to arrive. +func (sink *Sink) Message(N int) (Message, error) { + sink.lock.Lock() + defer sink.lock.Unlock() + + for len(sink.msgs) <= N { + sink.mcond.Wait() + if sink.err != nil { + return Message{}, sink.err + } + } + return sink.msgs[N], nil +} + +// Message is recorded for each datagram received. +type Message struct { + Raw []byte +} + +type DecodedAttr struct { + Key, Val string +} + +func (m *Message) Decode() (msg string, attr []DecodedAttr, err error) { + raw := m.Raw + var errs []error + +DecodeLoop: + for len(raw) > 0 { + n, key := decodeAttrKey(raw) + raw = raw[n:] + switch { + case len(raw) == 0: + errs = append(errs, fmt.Errorf("unterminated attribute name %q", key)) + break DecodeLoop + case key == "": + errs = append(errs, errors.New("empty attribute name")) + } + + var val string + switch raw[0] { + case '=': + n, val = decodeAttrValText(raw[1:]) + case '\n': + var err error + n, val, err = decodeAttrValLen(raw[1:]) + if err != nil { + errs = append(errs, err) + } + } + raw = raw[1+n:] + if len(raw) == 0 { + errs = append(errs, fmt.Errorf("unterminated value for attribute %q", key)) + break DecodeLoop + } + + if raw[0] != '\n' { + errs = append(errs, errors.New("incorrectly terminated attribute value")) + } + raw = raw[1:] + + switch key { + case "MESSAGE": + msg = val + default: + attr = append(attr, DecodedAttr{Key: key, Val: val}) + } + } + + return msg, attr, errors.Join(errs...) +} + +func decodeAttrKey(raw []byte) (n int, key string) { + for i := range raw { + switch raw[i] { + case '\n', '=': + return i, string(raw[:i]) + } + } + return len(raw), string(raw) +} + +func decodeAttrValText(raw []byte) (n int, val string) { + term := bytes.IndexByte(raw, '\n') + if term == -1 { + term = len(raw) + } + return term, string(raw[:term]) +} + +func decodeAttrValLen(raw []byte) (n int, val string, err error) { + if len(raw) < 8 { + return len(raw), "", errors.New("not enough bytes for binary attribute value length") + } + amt := binary.LittleEndian.Uint64(raw) + raw = raw[8:] + if uint64(len(raw)) < amt { + return 8 + len(raw), string(raw), errors.New("not enough bytes for binary attribute value") + } + return int(8 + amt), string(raw[:amt]), nil +} + +// GetAttr returns the value of the attribute whose key name matches, and a +// boolean to indicate if it found a match. +func GetAttr(attr []DecodedAttr, key string) (value string, ok bool) { + for i := range attr { + if attr[i].Key == key { + return attr[i].Val, true + } + } + return "", false +} diff --git a/testsink/testsink.go b/testsink/testsink.go new file mode 100644 index 0000000..427a775 --- /dev/null +++ b/testsink/testsink.go @@ -0,0 +1,82 @@ +/* +Package testsink provides a partial implementation of the systemd-journald +Unix socket. Datagrams received on its socket are decoded and stored for unit +tests to examine. +*/ +package testsink + +import ( + "net" + "slices" + "sync" +) + +// Sink provides a Unix socket and captures messages sent to it using the +// systemd-journald wire protocol. +type Sink struct { + sock *net.UnixConn + stop chan struct{} + + lock sync.Mutex + mcond *sync.Cond + msgs []Message + err error +} + +// New returns a new Sink that is listening on the given path. +func New(sockpath string) (*Sink, error) { + sock, err := net.ListenUnixgram("unixgram", &net.UnixAddr{ + Name: sockpath, + Net: "unixgram", + }) + if err != nil { + return nil, err + } + sink := &Sink{ + sock: sock, + stop: make(chan struct{}, 1), + } + sink.mcond = sync.NewCond(&sink.lock) + go sink.stopper() + go sink.recv() + return sink, nil +} + +// Stop listening and close the socket. +func (sink *Sink) Stop() { + // non-blocking write onto channel; we only need to read from it once in + // order to stop the receiver, but using this rather than close ensures + // Stop() can be called multiple times without negative side effects + select { + case sink.stop <- struct{}{}: + default: + } +} + +func (sink *Sink) stopper() { + <-sink.stop + sink.sock.Close() +} + +func (sink *Sink) recv() { + buf := make([]byte, 131072) + for { + n, err := sink.sock.Read(buf) + + sink.lock.Lock() + if n > 0 { + sink.msgs = append(sink.msgs, Message{ + Raw: slices.Clone(buf[:n]), + }) + } + if err != nil { + sink.err = err + } + sink.lock.Unlock() + + sink.mcond.Broadcast() + if err != nil { + return + } + } +}