268 lines
5.0 KiB
Go
268 lines
5.0 KiB
Go
|
package derive
|
||
|
|
||
|
import (
|
||
|
"crypto/rsa"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/fs"
|
||
|
"os"
|
||
|
|
||
|
"github.com/spf13/cobra"
|
||
|
"src.lwithers.me.uk/go/rsa/pkg/pemfile"
|
||
|
"src.lwithers.me.uk/go/stdinprompt"
|
||
|
"src.lwithers.me.uk/go/writefile"
|
||
|
)
|
||
|
|
||
|
var (
	// public is bound to --public: emit only the public half of any private
	// key found in the input.
	public bool

	// pkcs1 is bound to --pkcs1: encode keys as PKCS#1 rather than PKCS#8
	// (private keys) or PKIX (public keys).
	pkcs1 bool

	// derOut is bound to --der: write raw DER bytes instead of PEM blocks.
	derOut bool
)
|
||
|
|
||
|
// Register the "derive" subcommand.
|
||
|
func Register(root *cobra.Command) {
|
||
|
cmd := &cobra.Command{
|
||
|
Use: "derive input.pem [output.pem]",
|
||
|
Short: "Derive public key from private, and/or change encoding",
|
||
|
Run: Derive,
|
||
|
Args: cobra.RangeArgs(1, 2),
|
||
|
}
|
||
|
|
||
|
cmd.Flags().BoolVarP(&public, "public", "", false, "Only write public key part")
|
||
|
cmd.Flags().BoolVarP(&pkcs1, "pkcs1", "", false, "Write key as PKCS#1 rather than PKCS#8 / PKIX")
|
||
|
cmd.Flags().BoolVarP(&derOut, "der", "", false, "Write key as DER rather than PEM")
|
||
|
|
||
|
root.AddCommand(cmd)
|
||
|
}
|
||
|
|
||
|
// Derive a new form of output from the input file. Can derive public key from
|
||
|
// private, and/or change encoding.
|
||
|
func Derive(cmd *cobra.Command, args []string) {
|
||
|
if err := deriveAux(cmd, args); err != nil {
|
||
|
fmt.Fprintln(os.Stderr, err)
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func deriveAux(cmd *cobra.Command, args []string) error {
|
||
|
// prepare output stream
|
||
|
var (
|
||
|
out io.Writer = os.Stdout
|
||
|
commit = func() error { return nil }
|
||
|
)
|
||
|
|
||
|
if len(args) > 1 {
|
||
|
ff, f, err := writefile.New(args[1])
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
out = f
|
||
|
commit = func() error { return writefile.Commit(ff, f) }
|
||
|
defer writefile.Abort(f)
|
||
|
}
|
||
|
|
||
|
// prepare input stream
|
||
|
var (
|
||
|
in io.Reader
|
||
|
infile string
|
||
|
)
|
||
|
if args[0] == "-" {
|
||
|
in = stdinprompt.New()
|
||
|
infile = "(stdin)"
|
||
|
} else {
|
||
|
infile = args[0]
|
||
|
f, err := os.Open(infile)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer f.Close()
|
||
|
in = f
|
||
|
}
|
||
|
|
||
|
// read all input data to memory
|
||
|
raw, err := io.ReadAll(in)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// parse PEM blocks from the input
|
||
|
var (
|
||
|
didSomething bool
|
||
|
entryCount int
|
||
|
)
|
||
|
parseLoop:
|
||
|
for len(raw) > 0 {
|
||
|
block, rest := pem.Decode(raw)
|
||
|
if block == nil {
|
||
|
break parseLoop
|
||
|
}
|
||
|
raw = rest
|
||
|
entryCount++
|
||
|
|
||
|
derivation := deriveNoop
|
||
|
switch block.Type {
|
||
|
case pemfile.TypePKCS1PrivateKey:
|
||
|
didSomething = true
|
||
|
derivation = derivePKCS1PrivateKey
|
||
|
|
||
|
case pemfile.TypePKCS8PrivateKey:
|
||
|
didSomething = true
|
||
|
derivation = derivePKCS8PrivateKey
|
||
|
|
||
|
case pemfile.TypePKCS1PublicKey:
|
||
|
didSomething = true
|
||
|
derivation = derivePKCS1PublicKey
|
||
|
|
||
|
case pemfile.TypePKIXPublicKey:
|
||
|
didSomething = true
|
||
|
derivation = derivePKIXPublicKey
|
||
|
|
||
|
case pemfile.TypeX509Certificate:
|
||
|
didSomething = true
|
||
|
derivation = deriveX509Certificate
|
||
|
}
|
||
|
|
||
|
if err := derivation(out, block.Bytes); err != nil {
|
||
|
return fmt.Errorf("PEM block #%d (%s): %w", entryCount, block.Type, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if !didSomething {
|
||
|
return &fs.PathError{
|
||
|
Path: infile,
|
||
|
Op: "parse PEM",
|
||
|
Err: errors.New("no PEM-format RSA keys found"),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return commit()
|
||
|
}
|
||
|
|
||
|
// deriveNoop is the derivation applied to PEM blocks of an unhandled type:
// it ignores both arguments, writes nothing and always succeeds.
func deriveNoop(_ io.Writer, _ []byte) error { return nil }
|
||
|
|
||
|
func derivePKCS1PrivateKey(out io.Writer, der []byte) error {
|
||
|
key, err := x509.ParsePKCS1PrivateKey(der)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if public {
|
||
|
return writePublic(out, &key.PublicKey)
|
||
|
}
|
||
|
return writePrivate(out, key)
|
||
|
}
|
||
|
|
||
|
func derivePKCS8PrivateKey(out io.Writer, der []byte) error {
|
||
|
key, err := x509.ParsePKCS8PrivateKey(der)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
rsakey, ok := key.(*rsa.PrivateKey)
|
||
|
if !ok {
|
||
|
return errors.New("not an RSA key")
|
||
|
}
|
||
|
|
||
|
if public {
|
||
|
return writePublic(out, &rsakey.PublicKey)
|
||
|
}
|
||
|
return writePrivate(out, rsakey)
|
||
|
}
|
||
|
|
||
|
func derivePKCS1PublicKey(out io.Writer, der []byte) error {
|
||
|
key, err := x509.ParsePKCS1PublicKey(der)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return writePublic(out, key)
|
||
|
}
|
||
|
|
||
|
func derivePKIXPublicKey(out io.Writer, der []byte) error {
|
||
|
key, err := x509.ParsePKIXPublicKey(der)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
rsakey, ok := key.(*rsa.PublicKey)
|
||
|
if !ok {
|
||
|
return errors.New("not an RSA key")
|
||
|
}
|
||
|
|
||
|
return writePublic(out, rsakey)
|
||
|
}
|
||
|
|
||
|
func deriveX509Certificate(out io.Writer, der []byte) error {
|
||
|
cert, err := x509.ParseCertificate(der)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
key, ok := cert.PublicKey.(*rsa.PublicKey)
|
||
|
if !ok {
|
||
|
return errors.New("not an RSA key")
|
||
|
}
|
||
|
|
||
|
return writePublic(out, key)
|
||
|
}
|
||
|
|
||
|
func writePrivate(out io.Writer, key *rsa.PrivateKey) error {
|
||
|
var (
|
||
|
der []byte
|
||
|
pemType string
|
||
|
err error
|
||
|
)
|
||
|
if pkcs1 {
|
||
|
der = x509.MarshalPKCS1PrivateKey(key)
|
||
|
pemType = pemfile.TypePKCS1PrivateKey
|
||
|
} else {
|
||
|
der, err = x509.MarshalPKCS8PrivateKey(key)
|
||
|
pemType = pemfile.TypePKCS8PrivateKey
|
||
|
}
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if derOut {
|
||
|
_, err = out.Write(der)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return pem.Encode(out, &pem.Block{
|
||
|
Type: pemType,
|
||
|
Bytes: der,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func writePublic(out io.Writer, key *rsa.PublicKey) error {
|
||
|
var (
|
||
|
der []byte
|
||
|
pemType string
|
||
|
err error
|
||
|
)
|
||
|
if pkcs1 {
|
||
|
der = x509.MarshalPKCS1PublicKey(key)
|
||
|
pemType = pemfile.TypePKCS1PublicKey
|
||
|
} else {
|
||
|
der, err = x509.MarshalPKIXPublicKey(key)
|
||
|
pemType = pemfile.TypePKIXPublicKey
|
||
|
}
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if derOut {
|
||
|
_, err = out.Write(der)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return pem.Encode(out, &pem.Block{
|
||
|
Type: pemType,
|
||
|
Bytes: der,
|
||
|
})
|
||
|
}
|