diff --git a/cmd/derive/derive.go b/cmd/derive/derive.go new file mode 100644 index 0000000..dc45bac --- /dev/null +++ b/cmd/derive/derive.go @@ -0,0 +1,267 @@ +package derive + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + "io/fs" + "os" + + "github.com/spf13/cobra" + "src.lwithers.me.uk/go/rsa/pkg/pemfile" + "src.lwithers.me.uk/go/stdinprompt" + "src.lwithers.me.uk/go/writefile" +) + +var ( + public bool + pkcs1 bool + derOut bool +) + +// Register the "derive" subcommand. +func Register(root *cobra.Command) { + cmd := &cobra.Command{ + Use: "derive input.pem [output.pem]", + Short: "Derive public key from private, and/or change encoding", + Run: Derive, + Args: cobra.RangeArgs(1, 2), + } + + cmd.Flags().BoolVarP(&public, "public", "", false, "Only write public key part") + cmd.Flags().BoolVarP(&pkcs1, "pkcs1", "", false, "Write key as PKCS#1 rather than PKCS#8 / PKIX") + cmd.Flags().BoolVarP(&derOut, "der", "", false, "Write key as DER rather than PEM") + + root.AddCommand(cmd) +} + +// Derive a new form of output from the input file. Can derive public key from +// private, and/or change encoding. +func Derive(cmd *cobra.Command, args []string) { + if err := deriveAux(cmd, args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func deriveAux(cmd *cobra.Command, args []string) error { + // prepare output stream + var ( + out io.Writer = os.Stdout + commit = func() error { return nil } + ) + + if len(args) > 1 { + ff, f, err := writefile.New(args[1]) + if err != nil { + return err + } + out = f + commit = func() error { return writefile.Commit(ff, f) } + defer writefile.Abort(f) + } + + // prepare input stream + var ( + in io.Reader + infile string + ) + if args[0] == "-" { + in = stdinprompt.New() + infile = "(stdin)" + } else { + infile = args[0] + f, err := os.Open(infile) + if err != nil { + return err + } + defer f.Close() + in = f + } + + // read all input data to memory + raw, err := io.ReadAll(in) + if err != nil { + return err + } + + // parse PEM blocks from the input + var ( + didSomething bool + entryCount int + ) +parseLoop: + for len(raw) > 0 { + block, rest := pem.Decode(raw) + if block == nil { + break parseLoop + } + raw = rest + entryCount++ + + derivation := deriveNoop + switch block.Type { + case pemfile.TypePKCS1PrivateKey: + didSomething = true + derivation = derivePKCS1PrivateKey + + case pemfile.TypePKCS8PrivateKey: + didSomething = true + derivation = derivePKCS8PrivateKey + + case pemfile.TypePKCS1PublicKey: + didSomething = true + derivation = derivePKCS1PublicKey + + case pemfile.TypePKIXPublicKey: + didSomething = true + derivation = derivePKIXPublicKey + + case pemfile.TypeX509Certificate: + didSomething = true + derivation = deriveX509Certificate + } + + if err := derivation(out, block.Bytes); err != nil { + return fmt.Errorf("PEM block #%d (%s): %w", entryCount, block.Type, err) + } + } + + if !didSomething { + return &fs.PathError{ + Path: infile, + Op: "parse PEM", + Err: errors.New("no PEM-format RSA keys found"), + } + } + + return commit() +} + +func deriveNoop(io.Writer, []byte) error { + return nil +} + +func derivePKCS1PrivateKey(out io.Writer, der []byte) error { + key, err := x509.ParsePKCS1PrivateKey(der) + if err != nil { + return err + } + + if public { + return writePublic(out, &key.PublicKey) + } + return writePrivate(out, key) +} + +func derivePKCS8PrivateKey(out io.Writer, der []byte) error { + key, err := x509.ParsePKCS8PrivateKey(der) + if err != nil { + return err + } + + rsakey, ok := key.(*rsa.PrivateKey) + if !ok { + return errors.New("not an RSA key") + } + + if public { + return writePublic(out, &rsakey.PublicKey) + } + return writePrivate(out, rsakey) +} + +func derivePKCS1PublicKey(out io.Writer, der []byte) error { + key, err := x509.ParsePKCS1PublicKey(der) + if err != nil { + return err + } + return writePublic(out, key) +} + +func derivePKIXPublicKey(out io.Writer, der []byte) error { + key, err := x509.ParsePKIXPublicKey(der) + if err != nil { + return err + } + + rsakey, ok := key.(*rsa.PublicKey) + if !ok { + return errors.New("not an RSA key") + } + + return writePublic(out, rsakey) +} + +func deriveX509Certificate(out io.Writer, der []byte) error { + cert, err := x509.ParseCertificate(der) + if err != nil { + return err + } + + key, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return errors.New("not an RSA key") + } + + return writePublic(out, key) +} + +func writePrivate(out io.Writer, key *rsa.PrivateKey) error { + var ( + der []byte + pemType string + err error + ) + if pkcs1 { + der = x509.MarshalPKCS1PrivateKey(key) + pemType = pemfile.TypePKCS1PrivateKey + } else { + der, err = x509.MarshalPKCS8PrivateKey(key) + pemType = pemfile.TypePKCS8PrivateKey + } + if err != nil { + return err + } + + if derOut { + _, err = out.Write(der) + return err + } + + return pem.Encode(out, &pem.Block{ + Type: pemType, + Bytes: der, + }) +} + +func writePublic(out io.Writer, key *rsa.PublicKey) error { + var ( + der []byte + pemType string + err error + ) + if pkcs1 { + der = x509.MarshalPKCS1PublicKey(key) + pemType = pemfile.TypePKCS1PublicKey + } else { + der, err = x509.MarshalPKIXPublicKey(key) + pemType = pemfile.TypePKIXPublicKey + } + if err != nil { + return err + } + + if derOut { + _, err = out.Write(der) + return err + } + + return pem.Encode(out, &pem.Block{ + Type: pemType, + Bytes: der, + }) +} diff --git a/main.go b/main.go index 4a7f40a..2ed9c98 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/cobra" "src.lwithers.me.uk/go/rsa/cmd/ca" "src.lwithers.me.uk/go/rsa/cmd/csr" + "src.lwithers.me.uk/go/rsa/cmd/derive" "src.lwithers.me.uk/go/rsa/cmd/inspect" "src.lwithers.me.uk/go/rsa/cmd/keygen" ) @@ -20,6 +21,7 @@ func main() { keygen.Register(root) csr.Register(root) ca.Register(root) + derive.Register(root) if err := root.Execute(); err != nil { // error will already have been displayed