rsa/cmd/derive/derive.go

268 lines
5.0 KiB
Go
Raw Normal View History

2023-04-29 11:55:41 +01:00
package derive
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"io/fs"
"os"
"github.com/spf13/cobra"
"src.lwithers.me.uk/go/rsa/pkg/pemfile"
"src.lwithers.me.uk/go/stdinprompt"
"src.lwithers.me.uk/go/writefile"
)
// Flag values for the "derive" subcommand. They are bound to command-line
// flags in Register and consulted by the derive*/write* helpers below.
var (
	public bool // --public: only write the public key part
	pkcs1  bool // --pkcs1: PKCS#1 rather than PKCS#8 / PKIX
	derOut bool // --der: DER rather than PEM
)
// Register attaches the "derive" subcommand to root. The subcommand takes an
// input path and an optional output path, and exposes the --public, --pkcs1
// and --der boolean flags (all defaulting to false).
func Register(root *cobra.Command) {
	derive := &cobra.Command{
		Use:   "derive input.pem [output.pem]",
		Short: "Derive public key from private, and/or change encoding",
		Args:  cobra.RangeArgs(1, 2),
		Run:   Derive,
	}

	flags := derive.Flags()
	flags.BoolVarP(&public, "public", "", false, "Only write public key part")
	flags.BoolVarP(&pkcs1, "pkcs1", "", false, "Write key as PKCS#1 rather than PKCS#8 / PKIX")
	flags.BoolVarP(&derOut, "der", "", false, "Write key as DER rather than PEM")

	root.AddCommand(derive)
}
// Derive is the Run handler for the "derive" subcommand. It re-encodes the key
// material found in args[0], optionally reducing private keys to their public
// part, and writes the result to args[1] if given.
//
// On failure the error is printed to stderr and the process exits with
// status 1.
func Derive(cmd *cobra.Command, args []string) {
	err := deriveAux(cmd, args)
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}
// deriveAux does the work of Derive, returning an error instead of exiting.
//
// args[0] names the input file; "-" reads from stdin. The optional args[1]
// names the output file. When args[1] is omitted or is "-", output goes to
// stdout, mirroring the "-" convention for the input.
//
// Every PEM block in the input is decoded in order. Recognised RSA key and
// certificate blocks are re-emitted according to the --public, --pkcs1 and
// --der flags; other block types are skipped.
//
// An error is returned if a recognised block fails to convert, or if no
// recognised block is found at all. A named output file is only committed
// after every block has been processed successfully. When writing to stdout,
// blocks converted before an error have already been written.
func deriveAux(cmd *cobra.Command, args []string) error {
	// Prepare the output stream.
	var (
		out    io.Writer = os.Stdout
		commit           = func() error { return nil }
	)
	if len(args) > 1 && args[1] != "-" {
		ff, f, err := writefile.New(args[1])
		if err != nil {
			return err
		}
		out = f
		commit = func() error { return writefile.Commit(ff, f) }
		// NOTE(review): this also runs after a successful commit; it assumes
		// writefile.Abort is harmless at that point — confirm.
		defer writefile.Abort(f)
	}

	// Prepare the input stream.
	var (
		in     io.Reader
		infile string
	)
	if args[0] == "-" {
		in = stdinprompt.New()
		infile = "(stdin)"
	} else {
		infile = args[0]
		f, err := os.Open(infile)
		if err != nil {
			return err
		}
		defer f.Close()
		in = f
	}

	// Read all input into memory so it can be split into PEM blocks.
	raw, err := io.ReadAll(in)
	if err != nil {
		return err
	}

	// Convert each PEM block in turn. entryCount is 1-based and counts every
	// decoded block, including skipped ones, so errors can identify the block.
	var (
		didSomething bool
		entryCount   int
	)
	for len(raw) > 0 {
		block, rest := pem.Decode(raw)
		if block == nil {
			// No further PEM blocks; any trailing non-PEM data is ignored.
			break
		}
		raw = rest
		entryCount++

		// Unrecognised block types keep deriveNoop and produce no output.
		derivation := deriveNoop
		switch block.Type {
		case pemfile.TypePKCS1PrivateKey:
			didSomething = true
			derivation = derivePKCS1PrivateKey
		case pemfile.TypePKCS8PrivateKey:
			didSomething = true
			derivation = derivePKCS8PrivateKey
		case pemfile.TypePKCS1PublicKey:
			didSomething = true
			derivation = derivePKCS1PublicKey
		case pemfile.TypePKIXPublicKey:
			didSomething = true
			derivation = derivePKIXPublicKey
		case pemfile.TypeX509Certificate:
			didSomething = true
			derivation = deriveX509Certificate
		}
		if err := derivation(out, block.Bytes); err != nil {
			return fmt.Errorf("PEM block #%d (%s): %w", entryCount, block.Type, err)
		}
	}

	if !didSomething {
		return &fs.PathError{
			Path: infile,
			Op:   "parse PEM",
			Err:  errors.New("no PEM-format RSA keys found"),
		}
	}
	return commit()
}
// deriveNoop is the derivation used for PEM blocks of an unrecognised type.
// It writes nothing and always succeeds.
func deriveNoop(_ io.Writer, _ []byte) error {
	return nil
}
// derivePKCS1PrivateKey parses der as a PKCS#1 RSA private key. It writes the
// whole key via writePrivate, or only its public half via writePublic when
// --public is set. Parse and write errors are returned unchanged.
func derivePKCS1PrivateKey(out io.Writer, der []byte) error {
	priv, err := x509.ParsePKCS1PrivateKey(der)
	switch {
	case err != nil:
		return err
	case public:
		return writePublic(out, &priv.PublicKey)
	default:
		return writePrivate(out, priv)
	}
}
// derivePKCS8PrivateKey parses der as a PKCS#8 private key, which must hold
// an RSA key. It writes the whole key via writePrivate, or only its public
// half via writePublic when --public is set.
func derivePKCS8PrivateKey(out io.Writer, der []byte) error {
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return err
	}

	priv, isRSA := parsed.(*rsa.PrivateKey)
	switch {
	case !isRSA:
		return errors.New("not an RSA key")
	case public:
		return writePublic(out, &priv.PublicKey)
	}
	return writePrivate(out, priv)
}
// derivePKCS1PublicKey parses der as a PKCS#1 RSA public key and writes it
// back out via writePublic.
func derivePKCS1PublicKey(out io.Writer, der []byte) error {
	pub, err := x509.ParsePKCS1PublicKey(der)
	if err == nil {
		err = writePublic(out, pub)
	}
	return err
}
// derivePKIXPublicKey parses der as a PKIX (SubjectPublicKeyInfo) public key.
// The key must be RSA; it is written back out via writePublic.
func derivePKIXPublicKey(out io.Writer, der []byte) error {
	parsed, err := x509.ParsePKIXPublicKey(der)
	if err != nil {
		return err
	}
	if pub, isRSA := parsed.(*rsa.PublicKey); isRSA {
		return writePublic(out, pub)
	}
	return errors.New("not an RSA key")
}
// deriveX509Certificate parses der as an X.509 certificate and writes its
// subject public key via writePublic. An error is returned if that key is
// not RSA.
func deriveX509Certificate(out io.Writer, der []byte) error {
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return err
	}
	switch pub := cert.PublicKey.(type) {
	case *rsa.PublicKey:
		return writePublic(out, pub)
	default:
		return errors.New("not an RSA key")
	}
}
func writePrivate(out io.Writer, key *rsa.PrivateKey) error {
var (
der []byte
pemType string
err error
)
if pkcs1 {
der = x509.MarshalPKCS1PrivateKey(key)
pemType = pemfile.TypePKCS1PrivateKey
} else {
der, err = x509.MarshalPKCS8PrivateKey(key)
pemType = pemfile.TypePKCS8PrivateKey
}
if err != nil {
return err
}
if derOut {
_, err = out.Write(der)
return err
}
return pem.Encode(out, &pem.Block{
Type: pemType,
Bytes: der,
})
}
// writePublic serialises key to out. It uses PKCS#1 when --pkcs1 is set and
// PKIX otherwise. The encoded bytes are written raw with --der, or wrapped in
// a PEM block of the matching type otherwise.
func writePublic(out io.Writer, key *rsa.PublicKey) error {
	var (
		pemType string
		der     []byte
	)
	switch {
	case pkcs1:
		pemType, der = pemfile.TypePKCS1PublicKey, x509.MarshalPKCS1PublicKey(key)
	default:
		pemType = pemfile.TypePKIXPublicKey
		var marshalErr error
		if der, marshalErr = x509.MarshalPKIXPublicKey(key); marshalErr != nil {
			return marshalErr
		}
	}

	if derOut {
		_, err := out.Write(der)
		return err
	}
	block := &pem.Block{Type: pemType, Bytes: der}
	return pem.Encode(out, block)
}