package derive import ( "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "io" "io/fs" "os" "github.com/spf13/cobra" "src.lwithers.me.uk/go/rsa/pkg/pemfile" "src.lwithers.me.uk/go/stdinprompt" "src.lwithers.me.uk/go/writefile" ) var ( public bool pkcs1 bool derOut bool ) // Register the "derive" subcommand. func Register(root *cobra.Command) { cmd := &cobra.Command{ Use: "derive input.pem [output.pem]", Short: "Derive public key from private, and/or change encoding", Run: Derive, Args: cobra.RangeArgs(1, 2), } cmd.Flags().BoolVarP(&public, "public", "", false, "Only write public key part") cmd.Flags().BoolVarP(&pkcs1, "pkcs1", "", false, "Write key as PKCS#1 rather than PKCS#8 / PKIX") cmd.Flags().BoolVarP(&derOut, "der", "", false, "Write key as DER rather than PEM") root.AddCommand(cmd) } // Derive a new form of output from the input file. Can derive public key from // private, and/or change encoding. func Derive(cmd *cobra.Command, args []string) { if err := deriveAux(cmd, args); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } func deriveAux(cmd *cobra.Command, args []string) error { // prepare output stream var ( out io.Writer = os.Stdout commit = func() error { return nil } ) if len(args) > 1 { ff, f, err := writefile.New(args[1]) if err != nil { return err } out = f commit = func() error { return writefile.Commit(ff, f) } defer writefile.Abort(f) } // prepare input stream var ( in io.Reader infile string ) if args[0] == "-" { in = stdinprompt.New() infile = "(stdin)" } else { infile = args[0] f, err := os.Open(infile) if err != nil { return err } defer f.Close() in = f } // read all input data to memory raw, err := io.ReadAll(in) if err != nil { return err } // parse PEM blocks from the input var ( didSomething bool entryCount int ) parseLoop: for len(raw) > 0 { block, rest := pem.Decode(raw) if block == nil { break parseLoop } raw = rest entryCount++ derivation := deriveNoop switch block.Type { case pemfile.TypePKCS1PrivateKey: didSomething = true derivation = derivePKCS1PrivateKey case pemfile.TypePKCS8PrivateKey: didSomething = true derivation = derivePKCS8PrivateKey case pemfile.TypePKCS1PublicKey: didSomething = true derivation = derivePKCS1PublicKey case pemfile.TypePKIXPublicKey: didSomething = true derivation = derivePKIXPublicKey case pemfile.TypeX509Certificate: didSomething = true derivation = deriveX509Certificate } if err := derivation(out, block.Bytes); err != nil { return fmt.Errorf("PEM block #%d (%s): %w", entryCount, block.Type, err) } } if !didSomething { return &fs.PathError{ Path: infile, Op: "parse PEM", Err: errors.New("no PEM-format RSA keys found"), } } return commit() } func deriveNoop(io.Writer, []byte) error { return nil } func derivePKCS1PrivateKey(out io.Writer, der []byte) error { key, err := x509.ParsePKCS1PrivateKey(der) if err != nil { return err } if public { return writePublic(out, &key.PublicKey) } return writePrivate(out, key) } func derivePKCS8PrivateKey(out io.Writer, der []byte) error { key, err := x509.ParsePKCS8PrivateKey(der) if err != nil { return err } rsakey, ok := key.(*rsa.PrivateKey) if !ok { return errors.New("not an RSA key") } if public { return writePublic(out, &rsakey.PublicKey) } return writePrivate(out, rsakey) } func derivePKCS1PublicKey(out io.Writer, der []byte) error { key, err := x509.ParsePKCS1PublicKey(der) if err != nil { return err } return writePublic(out, key) } func derivePKIXPublicKey(out io.Writer, der []byte) error { key, err := x509.ParsePKIXPublicKey(der) if err != nil { return err } rsakey, ok := key.(*rsa.PublicKey) if !ok { return errors.New("not an RSA key") } return writePublic(out, rsakey) } func deriveX509Certificate(out io.Writer, der []byte) error { cert, err := x509.ParseCertificate(der) if err != nil { return err } key, ok := cert.PublicKey.(*rsa.PublicKey) if !ok { return errors.New("not an RSA key") } return writePublic(out, key) } func writePrivate(out io.Writer, key *rsa.PrivateKey) error { var ( der []byte pemType string err error ) if pkcs1 { der = x509.MarshalPKCS1PrivateKey(key) pemType = pemfile.TypePKCS1PrivateKey } else { der, err = x509.MarshalPKCS8PrivateKey(key) pemType = pemfile.TypePKCS8PrivateKey } if err != nil { return err } if derOut { _, err = out.Write(der) return err } return pem.Encode(out, &pem.Block{ Type: pemType, Bytes: der, }) } func writePublic(out io.Writer, key *rsa.PublicKey) error { var ( der []byte pemType string err error ) if pkcs1 { der = x509.MarshalPKCS1PublicKey(key) pemType = pemfile.TypePKCS1PublicKey } else { der, err = x509.MarshalPKIXPublicKey(key) pemType = pemfile.TypePKIXPublicKey } if err != nil { return err } if derOut { _, err = out.Write(der) return err } return pem.Encode(out, &pem.Block{ Type: pemType, Bytes: der, }) }