parent
befed9c20c
commit
216f0477b5
@ -0,0 +1,615 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/* |
||||||
|
Package agent implements a client to an ssh-agent daemon. |
||||||
|
|
||||||
|
References: |
||||||
|
[PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD
|
||||||
|
*/ |
||||||
|
package agent // import "golang.org/x/crypto/ssh/agent"
|
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/dsa" |
||||||
|
"crypto/ecdsa" |
||||||
|
"crypto/elliptic" |
||||||
|
"crypto/rsa" |
||||||
|
"encoding/base64" |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"math/big" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
// Agent represents the capabilities of an ssh-agent.
|
||||||
|
type Agent interface { |
||||||
|
// List returns the identities known to the agent.
|
||||||
|
List() ([]*Key, error) |
||||||
|
|
||||||
|
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||||
|
// in [PROTOCOL.agent] section 2.6.2.
|
||||||
|
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) |
||||||
|
|
||||||
|
// Add adds a private key to the agent.
|
||||||
|
Add(key AddedKey) error |
||||||
|
|
||||||
|
// Remove removes all identities with the given public key.
|
||||||
|
Remove(key ssh.PublicKey) error |
||||||
|
|
||||||
|
// RemoveAll removes all identities.
|
||||||
|
RemoveAll() error |
||||||
|
|
||||||
|
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||||
|
Lock(passphrase []byte) error |
||||||
|
|
||||||
|
// Unlock undoes the effect of Lock
|
||||||
|
Unlock(passphrase []byte) error |
||||||
|
|
||||||
|
// Signers returns signers for all the known keys.
|
||||||
|
Signers() ([]ssh.Signer, error) |
||||||
|
} |
||||||
|
|
||||||
|
// AddedKey describes an SSH key to be added to an Agent.
|
||||||
|
type AddedKey struct { |
||||||
|
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
|
||||||
|
// *ecdsa.PrivateKey, which will be inserted into the agent.
|
||||||
|
PrivateKey interface{} |
||||||
|
// Certificate, if not nil, is communicated to the agent and will be
|
||||||
|
// stored with the key.
|
||||||
|
Certificate *ssh.Certificate |
||||||
|
// Comment is an optional, free-form string.
|
||||||
|
Comment string |
||||||
|
// LifetimeSecs, if not zero, is the number of seconds that the
|
||||||
|
// agent will store the key for.
|
||||||
|
LifetimeSecs uint32 |
||||||
|
// ConfirmBeforeUse, if true, requests that the agent confirm with the
|
||||||
|
// user before each use of this key.
|
||||||
|
ConfirmBeforeUse bool |
||||||
|
} |
||||||
|
|
||||||
|
// See [PROTOCOL.agent], section 3.
|
||||||
|
const ( |
||||||
|
agentRequestV1Identities = 1 |
||||||
|
|
||||||
|
// 3.2 Requests from client to agent for protocol 2 key operations
|
||||||
|
agentAddIdentity = 17 |
||||||
|
agentRemoveIdentity = 18 |
||||||
|
agentRemoveAllIdentities = 19 |
||||||
|
agentAddIdConstrained = 25 |
||||||
|
|
||||||
|
// 3.3 Key-type independent requests from client to agent
|
||||||
|
agentAddSmartcardKey = 20 |
||||||
|
agentRemoveSmartcardKey = 21 |
||||||
|
agentLock = 22 |
||||||
|
agentUnlock = 23 |
||||||
|
agentAddSmartcardKeyConstrained = 26 |
||||||
|
|
||||||
|
// 3.7 Key constraint identifiers
|
||||||
|
agentConstrainLifetime = 1 |
||||||
|
agentConstrainConfirm = 2 |
||||||
|
) |
||||||
|
|
||||||
|
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
||||||
|
// is a sanity check, not a limit in the spec.
|
||||||
|
const maxAgentResponseBytes = 16 << 20 |
||||||
|
|
||||||
|
// Agent messages:
|
||||||
|
// These structures mirror the wire format of the corresponding ssh agent
|
||||||
|
// messages found in [PROTOCOL.agent].
|
||||||
|
|
||||||
|
// 3.4 Generic replies from agent to client
|
||||||
|
const agentFailure = 5 |
||||||
|
|
||||||
|
type failureAgentMsg struct{} |
||||||
|
|
||||||
|
const agentSuccess = 6 |
||||||
|
|
||||||
|
type successAgentMsg struct{} |
||||||
|
|
||||||
|
// See [PROTOCOL.agent], section 2.5.2.
|
||||||
|
const agentRequestIdentities = 11 |
||||||
|
|
||||||
|
type requestIdentitiesAgentMsg struct{} |
||||||
|
|
||||||
|
// See [PROTOCOL.agent], section 2.5.2.
|
||||||
|
const agentIdentitiesAnswer = 12 |
||||||
|
|
||||||
|
type identitiesAnswerAgentMsg struct { |
||||||
|
NumKeys uint32 `sshtype:"12"` |
||||||
|
Keys []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See [PROTOCOL.agent], section 2.6.2.
|
||||||
|
const agentSignRequest = 13 |
||||||
|
|
||||||
|
type signRequestAgentMsg struct { |
||||||
|
KeyBlob []byte `sshtype:"13"` |
||||||
|
Data []byte |
||||||
|
Flags uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// See [PROTOCOL.agent], section 2.6.2.
|
||||||
|
|
||||||
|
// 3.6 Replies from agent to client for protocol 2 key operations
|
||||||
|
const agentSignResponse = 14 |
||||||
|
|
||||||
|
type signResponseAgentMsg struct { |
||||||
|
SigBlob []byte `sshtype:"14"` |
||||||
|
} |
||||||
|
|
||||||
|
type publicKey struct { |
||||||
|
Format string |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// Key represents a protocol 2 public key as defined in
|
||||||
|
// [PROTOCOL.agent], section 2.5.2.
|
||||||
|
type Key struct { |
||||||
|
Format string |
||||||
|
Blob []byte |
||||||
|
Comment string |
||||||
|
} |
||||||
|
|
||||||
|
func clientErr(err error) error { |
||||||
|
return fmt.Errorf("agent: client error: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// String returns the storage form of an agent key with the format, base64
|
||||||
|
// encoded serialized key, and the comment if it is not empty.
|
||||||
|
func (k *Key) String() string { |
||||||
|
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) |
||||||
|
|
||||||
|
if k.Comment != "" { |
||||||
|
s += " " + k.Comment |
||||||
|
} |
||||||
|
|
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
// Type returns the public key type.
|
||||||
|
func (k *Key) Type() string { |
||||||
|
return k.Format |
||||||
|
} |
||||||
|
|
||||||
|
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
|
||||||
|
func (k *Key) Marshal() []byte { |
||||||
|
return k.Blob |
||||||
|
} |
||||||
|
|
||||||
|
// Verify satisfies the ssh.PublicKey interface, but is not
|
||||||
|
// implemented for agent keys.
|
||||||
|
func (k *Key) Verify(data []byte, sig *ssh.Signature) error { |
||||||
|
return errors.New("agent: agent key does not know how to verify") |
||||||
|
} |
||||||
|
|
||||||
|
type wireKey struct { |
||||||
|
Format string |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
func parseKey(in []byte) (out *Key, rest []byte, err error) { |
||||||
|
var record struct { |
||||||
|
Blob []byte |
||||||
|
Comment string |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
if err := ssh.Unmarshal(in, &record); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var wk wireKey |
||||||
|
if err := ssh.Unmarshal(record.Blob, &wk); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &Key{ |
||||||
|
Format: wk.Format, |
||||||
|
Blob: record.Blob, |
||||||
|
Comment: record.Comment, |
||||||
|
}, record.Rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
// client is a client for an ssh-agent process.
|
||||||
|
type client struct { |
||||||
|
// conn is typically a *net.UnixConn
|
||||||
|
conn io.ReadWriter |
||||||
|
// mu is used to prevent concurrent access to the agent
|
||||||
|
mu sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
// NewClient returns an Agent that talks to an ssh-agent process over
|
||||||
|
// the given connection.
|
||||||
|
func NewClient(rw io.ReadWriter) Agent { |
||||||
|
return &client{conn: rw} |
||||||
|
} |
||||||
|
|
||||||
|
// call sends an RPC to the agent. On success, the reply is
|
||||||
|
// unmarshaled into reply and replyType is set to the first byte of
|
||||||
|
// the reply, which contains the type of the message.
|
||||||
|
func (c *client) call(req []byte) (reply interface{}, err error) { |
||||||
|
c.mu.Lock() |
||||||
|
defer c.mu.Unlock() |
||||||
|
|
||||||
|
msg := make([]byte, 4+len(req)) |
||||||
|
binary.BigEndian.PutUint32(msg, uint32(len(req))) |
||||||
|
copy(msg[4:], req) |
||||||
|
if _, err = c.conn.Write(msg); err != nil { |
||||||
|
return nil, clientErr(err) |
||||||
|
} |
||||||
|
|
||||||
|
var respSizeBuf [4]byte |
||||||
|
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { |
||||||
|
return nil, clientErr(err) |
||||||
|
} |
||||||
|
respSize := binary.BigEndian.Uint32(respSizeBuf[:]) |
||||||
|
if respSize > maxAgentResponseBytes { |
||||||
|
return nil, clientErr(err) |
||||||
|
} |
||||||
|
|
||||||
|
buf := make([]byte, respSize) |
||||||
|
if _, err = io.ReadFull(c.conn, buf); err != nil { |
||||||
|
return nil, clientErr(err) |
||||||
|
} |
||||||
|
reply, err = unmarshal(buf) |
||||||
|
if err != nil { |
||||||
|
return nil, clientErr(err) |
||||||
|
} |
||||||
|
return reply, err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) simpleCall(req []byte) error { |
||||||
|
resp, err := c.call(req) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, ok := resp.(*successAgentMsg); ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return errors.New("agent: failure") |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) RemoveAll() error { |
||||||
|
return c.simpleCall([]byte{agentRemoveAllIdentities}) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) Remove(key ssh.PublicKey) error { |
||||||
|
req := ssh.Marshal(&agentRemoveIdentityMsg{ |
||||||
|
KeyBlob: key.Marshal(), |
||||||
|
}) |
||||||
|
return c.simpleCall(req) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) Lock(passphrase []byte) error { |
||||||
|
req := ssh.Marshal(&agentLockMsg{ |
||||||
|
Passphrase: passphrase, |
||||||
|
}) |
||||||
|
return c.simpleCall(req) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) Unlock(passphrase []byte) error { |
||||||
|
req := ssh.Marshal(&agentUnlockMsg{ |
||||||
|
Passphrase: passphrase, |
||||||
|
}) |
||||||
|
return c.simpleCall(req) |
||||||
|
} |
||||||
|
|
||||||
|
// List returns the identities known to the agent.
|
||||||
|
func (c *client) List() ([]*Key, error) { |
||||||
|
// see [PROTOCOL.agent] section 2.5.2.
|
||||||
|
req := []byte{agentRequestIdentities} |
||||||
|
|
||||||
|
msg, err := c.call(req) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
switch msg := msg.(type) { |
||||||
|
case *identitiesAnswerAgentMsg: |
||||||
|
if msg.NumKeys > maxAgentResponseBytes/8 { |
||||||
|
return nil, errors.New("agent: too many keys in agent reply") |
||||||
|
} |
||||||
|
keys := make([]*Key, msg.NumKeys) |
||||||
|
data := msg.Keys |
||||||
|
for i := uint32(0); i < msg.NumKeys; i++ { |
||||||
|
var key *Key |
||||||
|
var err error |
||||||
|
if key, data, err = parseKey(data); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
keys[i] = key |
||||||
|
} |
||||||
|
return keys, nil |
||||||
|
case *failureAgentMsg: |
||||||
|
return nil, errors.New("agent: failed to list keys") |
||||||
|
} |
||||||
|
panic("unreachable") |
||||||
|
} |
||||||
|
|
||||||
|
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||||
|
// in [PROTOCOL.agent] section 2.6.2.
|
||||||
|
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||||
|
req := ssh.Marshal(signRequestAgentMsg{ |
||||||
|
KeyBlob: key.Marshal(), |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
|
||||||
|
msg, err := c.call(req) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
switch msg := msg.(type) { |
||||||
|
case *signResponseAgentMsg: |
||||||
|
var sig ssh.Signature |
||||||
|
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &sig, nil |
||||||
|
case *failureAgentMsg: |
||||||
|
return nil, errors.New("agent: failed to sign challenge") |
||||||
|
} |
||||||
|
panic("unreachable") |
||||||
|
} |
||||||
|
|
||||||
|
// unmarshal parses an agent message in packet, returning the parsed
|
||||||
|
// form and the message type of packet.
|
||||||
|
func unmarshal(packet []byte) (interface{}, error) { |
||||||
|
if len(packet) < 1 { |
||||||
|
return nil, errors.New("agent: empty packet") |
||||||
|
} |
||||||
|
var msg interface{} |
||||||
|
switch packet[0] { |
||||||
|
case agentFailure: |
||||||
|
return new(failureAgentMsg), nil |
||||||
|
case agentSuccess: |
||||||
|
return new(successAgentMsg), nil |
||||||
|
case agentIdentitiesAnswer: |
||||||
|
msg = new(identitiesAnswerAgentMsg) |
||||||
|
case agentSignResponse: |
||||||
|
msg = new(signResponseAgentMsg) |
||||||
|
default: |
||||||
|
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) |
||||||
|
} |
||||||
|
if err := ssh.Unmarshal(packet, msg); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return msg, nil |
||||||
|
} |
||||||
|
|
||||||
|
type rsaKeyMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
N *big.Int |
||||||
|
E *big.Int |
||||||
|
D *big.Int |
||||||
|
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||||
|
P *big.Int |
||||||
|
Q *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
type dsaKeyMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
P *big.Int |
||||||
|
Q *big.Int |
||||||
|
G *big.Int |
||||||
|
Y *big.Int |
||||||
|
X *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
type ecdsaKeyMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
Curve string |
||||||
|
KeyBytes []byte |
||||||
|
D *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// Insert adds a private key to the agent.
|
||||||
|
func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { |
||||||
|
var req []byte |
||||||
|
switch k := s.(type) { |
||||||
|
case *rsa.PrivateKey: |
||||||
|
if len(k.Primes) != 2 { |
||||||
|
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||||
|
} |
||||||
|
k.Precompute() |
||||||
|
req = ssh.Marshal(rsaKeyMsg{ |
||||||
|
Type: ssh.KeyAlgoRSA, |
||||||
|
N: k.N, |
||||||
|
E: big.NewInt(int64(k.E)), |
||||||
|
D: k.D, |
||||||
|
Iqmp: k.Precomputed.Qinv, |
||||||
|
P: k.Primes[0], |
||||||
|
Q: k.Primes[1], |
||||||
|
Comments: comment, |
||||||
|
Constraints: constraints, |
||||||
|
}) |
||||||
|
case *dsa.PrivateKey: |
||||||
|
req = ssh.Marshal(dsaKeyMsg{ |
||||||
|
Type: ssh.KeyAlgoDSA, |
||||||
|
P: k.P, |
||||||
|
Q: k.Q, |
||||||
|
G: k.G, |
||||||
|
Y: k.Y, |
||||||
|
X: k.X, |
||||||
|
Comments: comment, |
||||||
|
Constraints: constraints, |
||||||
|
}) |
||||||
|
case *ecdsa.PrivateKey: |
||||||
|
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) |
||||||
|
req = ssh.Marshal(ecdsaKeyMsg{ |
||||||
|
Type: "ecdsa-sha2-" + nistID, |
||||||
|
Curve: nistID, |
||||||
|
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), |
||||||
|
D: k.D, |
||||||
|
Comments: comment, |
||||||
|
Constraints: constraints, |
||||||
|
}) |
||||||
|
default: |
||||||
|
return fmt.Errorf("agent: unsupported key type %T", s) |
||||||
|
} |
||||||
|
|
||||||
|
// if constraints are present then the message type needs to be changed.
|
||||||
|
if len(constraints) != 0 { |
||||||
|
req[0] = agentAddIdConstrained |
||||||
|
} |
||||||
|
|
||||||
|
resp, err := c.call(req) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, ok := resp.(*successAgentMsg); ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return errors.New("agent: failure") |
||||||
|
} |
||||||
|
|
||||||
|
type rsaCertMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
CertBytes []byte |
||||||
|
D *big.Int |
||||||
|
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||||
|
P *big.Int |
||||||
|
Q *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
type dsaCertMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
CertBytes []byte |
||||||
|
X *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
type ecdsaCertMsg struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
CertBytes []byte |
||||||
|
D *big.Int |
||||||
|
Comments string |
||||||
|
Constraints []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// Insert adds a private key to the agent. If a certificate is given,
|
||||||
|
// that certificate is added instead as public key.
|
||||||
|
func (c *client) Add(key AddedKey) error { |
||||||
|
var constraints []byte |
||||||
|
|
||||||
|
if secs := key.LifetimeSecs; secs != 0 { |
||||||
|
constraints = append(constraints, agentConstrainLifetime) |
||||||
|
|
||||||
|
var secsBytes [4]byte |
||||||
|
binary.BigEndian.PutUint32(secsBytes[:], secs) |
||||||
|
constraints = append(constraints, secsBytes[:]...) |
||||||
|
} |
||||||
|
|
||||||
|
if key.ConfirmBeforeUse { |
||||||
|
constraints = append(constraints, agentConstrainConfirm) |
||||||
|
} |
||||||
|
|
||||||
|
if cert := key.Certificate; cert == nil { |
||||||
|
return c.insertKey(key.PrivateKey, key.Comment, constraints) |
||||||
|
} else { |
||||||
|
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { |
||||||
|
var req []byte |
||||||
|
switch k := s.(type) { |
||||||
|
case *rsa.PrivateKey: |
||||||
|
if len(k.Primes) != 2 { |
||||||
|
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||||
|
} |
||||||
|
k.Precompute() |
||||||
|
req = ssh.Marshal(rsaCertMsg{ |
||||||
|
Type: cert.Type(), |
||||||
|
CertBytes: cert.Marshal(), |
||||||
|
D: k.D, |
||||||
|
Iqmp: k.Precomputed.Qinv, |
||||||
|
P: k.Primes[0], |
||||||
|
Q: k.Primes[1], |
||||||
|
Comments: comment, |
||||||
|
Constraints: constraints, |
||||||
|
}) |
||||||
|
case *dsa.PrivateKey: |
||||||
|
req = ssh.Marshal(dsaCertMsg{ |
||||||
|
Type: cert.Type(), |
||||||
|
CertBytes: cert.Marshal(), |
||||||
|
X: k.X, |
||||||
|
Comments: comment, |
||||||
|
}) |
||||||
|
case *ecdsa.PrivateKey: |
||||||
|
req = ssh.Marshal(ecdsaCertMsg{ |
||||||
|
Type: cert.Type(), |
||||||
|
CertBytes: cert.Marshal(), |
||||||
|
D: k.D, |
||||||
|
Comments: comment, |
||||||
|
}) |
||||||
|
default: |
||||||
|
return fmt.Errorf("agent: unsupported key type %T", s) |
||||||
|
} |
||||||
|
|
||||||
|
// if constraints are present then the message type needs to be changed.
|
||||||
|
if len(constraints) != 0 { |
||||||
|
req[0] = agentAddIdConstrained |
||||||
|
} |
||||||
|
|
||||||
|
signer, err := ssh.NewSignerFromKey(s) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||||
|
return errors.New("agent: signer and cert have different public key") |
||||||
|
} |
||||||
|
|
||||||
|
resp, err := c.call(req) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, ok := resp.(*successAgentMsg); ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return errors.New("agent: failure") |
||||||
|
} |
||||||
|
|
||||||
|
// Signers provides a callback for client authentication.
|
||||||
|
func (c *client) Signers() ([]ssh.Signer, error) { |
||||||
|
keys, err := c.List() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var result []ssh.Signer |
||||||
|
for _, k := range keys { |
||||||
|
result = append(result, &agentKeyringSigner{c, k}) |
||||||
|
} |
||||||
|
return result, nil |
||||||
|
} |
||||||
|
|
||||||
|
type agentKeyringSigner struct { |
||||||
|
agent *client |
||||||
|
pub ssh.PublicKey |
||||||
|
} |
||||||
|
|
||||||
|
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { |
||||||
|
return s.pub |
||||||
|
} |
||||||
|
|
||||||
|
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { |
||||||
|
// The agent has its own entropy source, so the rand argument is ignored.
|
||||||
|
return s.agent.Sign(s.pub, data) |
||||||
|
} |
@ -0,0 +1,287 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"errors" |
||||||
|
"net" |
||||||
|
"os" |
||||||
|
"os/exec" |
||||||
|
"path/filepath" |
||||||
|
"strconv" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
||||||
|
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { |
||||||
|
if testing.Short() { |
||||||
|
// ssh-agent is not always available, and the key
|
||||||
|
// types supported vary by platform.
|
||||||
|
t.Skip("skipping test due to -short") |
||||||
|
} |
||||||
|
|
||||||
|
bin, err := exec.LookPath("ssh-agent") |
||||||
|
if err != nil { |
||||||
|
t.Skip("could not find ssh-agent") |
||||||
|
} |
||||||
|
|
||||||
|
cmd := exec.Command(bin, "-s") |
||||||
|
out, err := cmd.Output() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("cmd.Output: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
/* Output looks like: |
||||||
|
|
||||||
|
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; |
||||||
|
SSH_AGENT_PID=15542; export SSH_AGENT_PID; |
||||||
|
echo Agent pid 15542; |
||||||
|
*/ |
||||||
|
fields := bytes.Split(out, []byte(";")) |
||||||
|
line := bytes.SplitN(fields[0], []byte("="), 2) |
||||||
|
line[0] = bytes.TrimLeft(line[0], "\n") |
||||||
|
if string(line[0]) != "SSH_AUTH_SOCK" { |
||||||
|
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) |
||||||
|
} |
||||||
|
socket = string(line[1]) |
||||||
|
|
||||||
|
line = bytes.SplitN(fields[2], []byte("="), 2) |
||||||
|
line[0] = bytes.TrimLeft(line[0], "\n") |
||||||
|
if string(line[0]) != "SSH_AGENT_PID" { |
||||||
|
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) |
||||||
|
} |
||||||
|
pidStr := line[1] |
||||||
|
pid, err := strconv.Atoi(string(pidStr)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Atoi(%q): %v", pidStr, err) |
||||||
|
} |
||||||
|
|
||||||
|
conn, err := net.Dial("unix", string(socket)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("net.Dial: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
ac := NewClient(conn) |
||||||
|
return ac, socket, func() { |
||||||
|
proc, _ := os.FindProcess(pid) |
||||||
|
if proc != nil { |
||||||
|
proc.Kill() |
||||||
|
} |
||||||
|
conn.Close() |
||||||
|
os.RemoveAll(filepath.Dir(socket)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||||
|
agent, _, cleanup := startAgent(t) |
||||||
|
defer cleanup() |
||||||
|
|
||||||
|
testAgentInterface(t, agent, key, cert, lifetimeSecs) |
||||||
|
} |
||||||
|
|
||||||
|
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||||
|
signer, err := ssh.NewSignerFromKey(key) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewSignerFromKey(%T): %v", key, err) |
||||||
|
} |
||||||
|
// The agent should start up empty.
|
||||||
|
if keys, err := agent.List(); err != nil { |
||||||
|
t.Fatalf("RequestIdentities: %v", err) |
||||||
|
} else if len(keys) > 0 { |
||||||
|
t.Fatalf("got %d keys, want 0: %v", len(keys), keys) |
||||||
|
} |
||||||
|
|
||||||
|
// Attempt to insert the key, with certificate if specified.
|
||||||
|
var pubKey ssh.PublicKey |
||||||
|
if cert != nil { |
||||||
|
err = agent.Add(AddedKey{ |
||||||
|
PrivateKey: key, |
||||||
|
Certificate: cert, |
||||||
|
Comment: "comment", |
||||||
|
LifetimeSecs: lifetimeSecs, |
||||||
|
}) |
||||||
|
pubKey = cert |
||||||
|
} else { |
||||||
|
err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) |
||||||
|
pubKey = signer.PublicKey() |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("insert(%T): %v", key, err) |
||||||
|
} |
||||||
|
|
||||||
|
// Did the key get inserted successfully?
|
||||||
|
if keys, err := agent.List(); err != nil { |
||||||
|
t.Fatalf("List: %v", err) |
||||||
|
} else if len(keys) != 1 { |
||||||
|
t.Fatalf("got %v, want 1 key", keys) |
||||||
|
} else if keys[0].Comment != "comment" { |
||||||
|
t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") |
||||||
|
} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { |
||||||
|
t.Fatalf("key mismatch") |
||||||
|
} |
||||||
|
|
||||||
|
// Can the agent make a valid signature?
|
||||||
|
data := []byte("hello") |
||||||
|
sig, err := agent.Sign(pubKey, data) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Sign(%s): %v", pubKey.Type(), err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := pubKey.Verify(data, sig); err != nil { |
||||||
|
t.Fatalf("Verify(%s): %v", pubKey.Type(), err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAgent(t *testing.T) { |
||||||
|
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { |
||||||
|
testAgent(t, testPrivateKeys[keyType], nil, 0) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCert(t *testing.T) { |
||||||
|
cert := &ssh.Certificate{ |
||||||
|
Key: testPublicKeys["rsa"], |
||||||
|
ValidBefore: ssh.CertTimeInfinity, |
||||||
|
CertType: ssh.UserCert, |
||||||
|
} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
|
||||||
|
testAgent(t, testPrivateKeys["rsa"], cert, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func TestConstraints(t *testing.T) { |
||||||
|
testAgent(t, testPrivateKeys["rsa"], nil, 3600 /* lifetime in seconds */) |
||||||
|
} |
||||||
|
|
||||||
|
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||||
|
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||||
|
// a write.)
|
||||||
|
func netPipe() (net.Conn, net.Conn, error) { |
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
defer listener.Close() |
||||||
|
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c2, err := listener.Accept() |
||||||
|
if err != nil { |
||||||
|
c1.Close() |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return c1, c2, nil |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuth(t *testing.T) { |
||||||
|
a, b, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
|
||||||
|
agent, _, cleanup := startAgent(t) |
||||||
|
defer cleanup() |
||||||
|
|
||||||
|
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { |
||||||
|
t.Errorf("Add: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
serverConf := ssh.ServerConfig{} |
||||||
|
serverConf.AddHostKey(testSigners["rsa"]) |
||||||
|
serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
||||||
|
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
return nil, errors.New("pubkey rejected") |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Server: %v", err) |
||||||
|
} |
||||||
|
conn.Close() |
||||||
|
}() |
||||||
|
|
||||||
|
conf := ssh.ClientConfig{} |
||||||
|
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) |
||||||
|
conn, _, _, err := ssh.NewClientConn(b, "", &conf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewClientConn: %v", err) |
||||||
|
} |
||||||
|
conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
func TestLockClient(t *testing.T) { |
||||||
|
agent, _, cleanup := startAgent(t) |
||||||
|
defer cleanup() |
||||||
|
testLockAgent(agent, t) |
||||||
|
} |
||||||
|
|
||||||
|
func testLockAgent(agent Agent, t *testing.T) { |
||||||
|
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { |
||||||
|
t.Errorf("Add: %v", err) |
||||||
|
} |
||||||
|
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { |
||||||
|
t.Errorf("Add: %v", err) |
||||||
|
} |
||||||
|
if keys, err := agent.List(); err != nil { |
||||||
|
t.Errorf("List: %v", err) |
||||||
|
} else if len(keys) != 2 { |
||||||
|
t.Errorf("Want 2 keys, got %v", keys) |
||||||
|
} |
||||||
|
|
||||||
|
passphrase := []byte("secret") |
||||||
|
if err := agent.Lock(passphrase); err != nil { |
||||||
|
t.Errorf("Lock: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if keys, err := agent.List(); err != nil { |
||||||
|
t.Errorf("List: %v", err) |
||||||
|
} else if len(keys) != 0 { |
||||||
|
t.Errorf("Want 0 keys, got %v", keys) |
||||||
|
} |
||||||
|
|
||||||
|
signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) |
||||||
|
if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { |
||||||
|
t.Fatalf("Sign did not fail") |
||||||
|
} |
||||||
|
|
||||||
|
if err := agent.Remove(signer.PublicKey()); err == nil { |
||||||
|
t.Fatalf("Remove did not fail") |
||||||
|
} |
||||||
|
|
||||||
|
if err := agent.RemoveAll(); err == nil { |
||||||
|
t.Fatalf("RemoveAll did not fail") |
||||||
|
} |
||||||
|
|
||||||
|
if err := agent.Unlock(nil); err == nil { |
||||||
|
t.Errorf("Unlock with wrong passphrase succeeded") |
||||||
|
} |
||||||
|
if err := agent.Unlock(passphrase); err != nil { |
||||||
|
t.Errorf("Unlock: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := agent.Remove(signer.PublicKey()); err != nil { |
||||||
|
t.Fatalf("Remove: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if keys, err := agent.List(); err != nil { |
||||||
|
t.Errorf("List: %v", err) |
||||||
|
} else if len(keys) != 1 { |
||||||
|
t.Errorf("Want 1 keys, got %v", keys) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,103 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
// RequestAgentForwarding sets up agent forwarding for the session.
|
||||||
|
// ForwardToAgent or ForwardToRemote should be called to route
|
||||||
|
// the authentication requests.
|
||||||
|
func RequestAgentForwarding(session *ssh.Session) error { |
||||||
|
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if !ok { |
||||||
|
return errors.New("forwarding request denied") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// ForwardToAgent routes authentication requests to the given keyring.
|
||||||
|
func ForwardToAgent(client *ssh.Client, keyring Agent) error { |
||||||
|
channels := client.HandleChannelOpen(channelType) |
||||||
|
if channels == nil { |
||||||
|
return errors.New("agent: already have handler for " + channelType) |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
for ch := range channels { |
||||||
|
channel, reqs, err := ch.Accept() |
||||||
|
if err != nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
go ssh.DiscardRequests(reqs) |
||||||
|
go func() { |
||||||
|
ServeAgent(keyring, channel) |
||||||
|
channel.Close() |
||||||
|
}() |
||||||
|
} |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
const channelType = "auth-agent@openssh.com" |
||||||
|
|
||||||
|
// ForwardToRemote routes authentication requests to the ssh-agent
|
||||||
|
// process serving on the given unix socket.
|
||||||
|
func ForwardToRemote(client *ssh.Client, addr string) error { |
||||||
|
channels := client.HandleChannelOpen(channelType) |
||||||
|
if channels == nil { |
||||||
|
return errors.New("agent: already have handler for " + channelType) |
||||||
|
} |
||||||
|
conn, err := net.Dial("unix", addr) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
conn.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
for ch := range channels { |
||||||
|
channel, reqs, err := ch.Accept() |
||||||
|
if err != nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
go ssh.DiscardRequests(reqs) |
||||||
|
go forwardUnixSocket(channel, addr) |
||||||
|
} |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func forwardUnixSocket(channel ssh.Channel, addr string) { |
||||||
|
conn, err := net.Dial("unix", addr) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
var wg sync.WaitGroup |
||||||
|
wg.Add(2) |
||||||
|
go func() { |
||||||
|
io.Copy(conn, channel) |
||||||
|
conn.(*net.UnixConn).CloseWrite() |
||||||
|
wg.Done() |
||||||
|
}() |
||||||
|
go func() { |
||||||
|
io.Copy(channel, conn) |
||||||
|
channel.CloseWrite() |
||||||
|
wg.Done() |
||||||
|
}() |
||||||
|
|
||||||
|
wg.Wait() |
||||||
|
conn.Close() |
||||||
|
channel.Close() |
||||||
|
} |
@ -0,0 +1,184 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"crypto/subtle" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
type privKey struct { |
||||||
|
signer ssh.Signer |
||||||
|
comment string |
||||||
|
} |
||||||
|
|
||||||
|
type keyring struct { |
||||||
|
mu sync.Mutex |
||||||
|
keys []privKey |
||||||
|
|
||||||
|
locked bool |
||||||
|
passphrase []byte |
||||||
|
} |
||||||
|
|
||||||
|
var errLocked = errors.New("agent: locked") |
||||||
|
|
||||||
|
// NewKeyring returns an Agent that holds keys in memory. It is safe
|
||||||
|
// for concurrent use by multiple goroutines.
|
||||||
|
func NewKeyring() Agent { |
||||||
|
return &keyring{} |
||||||
|
} |
||||||
|
|
||||||
|
// RemoveAll removes all identities.
|
||||||
|
func (r *keyring) RemoveAll() error { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return errLocked |
||||||
|
} |
||||||
|
|
||||||
|
r.keys = nil |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Remove removes all identities with the given public key.
|
||||||
|
func (r *keyring) Remove(key ssh.PublicKey) error { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return errLocked |
||||||
|
} |
||||||
|
|
||||||
|
want := key.Marshal() |
||||||
|
found := false |
||||||
|
for i := 0; i < len(r.keys); { |
||||||
|
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { |
||||||
|
found = true |
||||||
|
r.keys[i] = r.keys[len(r.keys)-1] |
||||||
|
r.keys = r.keys[len(r.keys)-1:] |
||||||
|
continue |
||||||
|
} else { |
||||||
|
i++ |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if !found { |
||||||
|
return errors.New("agent: key not found") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||||
|
func (r *keyring) Lock(passphrase []byte) error { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return errLocked |
||||||
|
} |
||||||
|
|
||||||
|
r.locked = true |
||||||
|
r.passphrase = passphrase |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Unlock undoes the effect of Lock
|
||||||
|
func (r *keyring) Unlock(passphrase []byte) error { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if !r.locked { |
||||||
|
return errors.New("agent: not locked") |
||||||
|
} |
||||||
|
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { |
||||||
|
return fmt.Errorf("agent: incorrect passphrase") |
||||||
|
} |
||||||
|
|
||||||
|
r.locked = false |
||||||
|
r.passphrase = nil |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// List returns the identities known to the agent.
|
||||||
|
func (r *keyring) List() ([]*Key, error) { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
// section 2.7: locked agents return empty.
|
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
var ids []*Key |
||||||
|
for _, k := range r.keys { |
||||||
|
pub := k.signer.PublicKey() |
||||||
|
ids = append(ids, &Key{ |
||||||
|
Format: pub.Type(), |
||||||
|
Blob: pub.Marshal(), |
||||||
|
Comment: k.comment}) |
||||||
|
} |
||||||
|
return ids, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Insert adds a private key to the keyring. If a certificate
|
||||||
|
// is given, that certificate is added as public key. Note that
|
||||||
|
// any constraints given are ignored.
|
||||||
|
func (r *keyring) Add(key AddedKey) error { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return errLocked |
||||||
|
} |
||||||
|
signer, err := ssh.NewSignerFromKey(key.PrivateKey) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if cert := key.Certificate; cert != nil { |
||||||
|
signer, err = ssh.NewCertSigner(cert, signer) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
r.keys = append(r.keys, privKey{signer, key.Comment}) |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Sign returns a signature for the data.
|
||||||
|
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return nil, errLocked |
||||||
|
} |
||||||
|
|
||||||
|
wanted := key.Marshal() |
||||||
|
for _, k := range r.keys { |
||||||
|
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { |
||||||
|
return k.signer.Sign(rand.Reader, data) |
||||||
|
} |
||||||
|
} |
||||||
|
return nil, errors.New("not found") |
||||||
|
} |
||||||
|
|
||||||
|
// Signers returns signers for all the known keys.
|
||||||
|
func (r *keyring) Signers() ([]ssh.Signer, error) { |
||||||
|
r.mu.Lock() |
||||||
|
defer r.mu.Unlock() |
||||||
|
if r.locked { |
||||||
|
return nil, errLocked |
||||||
|
} |
||||||
|
|
||||||
|
s := make([]ssh.Signer, 0, len(r.keys)) |
||||||
|
for _, k := range r.keys { |
||||||
|
s = append(s, k.signer) |
||||||
|
} |
||||||
|
return s, nil |
||||||
|
} |
@ -0,0 +1,209 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rsa" |
||||||
|
"encoding/binary" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"log" |
||||||
|
"math/big" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
// Server wraps an Agent and uses it to implement the agent side of
|
||||||
|
// the SSH-agent, wire protocol.
|
||||||
|
type server struct { |
||||||
|
agent Agent |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) processRequestBytes(reqData []byte) []byte { |
||||||
|
rep, err := s.processRequest(reqData) |
||||||
|
if err != nil { |
||||||
|
if err != errLocked { |
||||||
|
// TODO(hanwen): provide better logging interface?
|
||||||
|
log.Printf("agent %d: %v", reqData[0], err) |
||||||
|
} |
||||||
|
return []byte{agentFailure} |
||||||
|
} |
||||||
|
|
||||||
|
if err == nil && rep == nil { |
||||||
|
return []byte{agentSuccess} |
||||||
|
} |
||||||
|
|
||||||
|
return ssh.Marshal(rep) |
||||||
|
} |
||||||
|
|
||||||
|
func marshalKey(k *Key) []byte { |
||||||
|
var record struct { |
||||||
|
Blob []byte |
||||||
|
Comment string |
||||||
|
} |
||||||
|
record.Blob = k.Marshal() |
||||||
|
record.Comment = k.Comment |
||||||
|
|
||||||
|
return ssh.Marshal(&record) |
||||||
|
} |
||||||
|
|
||||||
|
type agentV1IdentityMsg struct { |
||||||
|
Numkeys uint32 `sshtype:"2"` |
||||||
|
} |
||||||
|
|
||||||
|
type agentRemoveIdentityMsg struct { |
||||||
|
KeyBlob []byte `sshtype:"18"` |
||||||
|
} |
||||||
|
|
||||||
|
type agentLockMsg struct { |
||||||
|
Passphrase []byte `sshtype:"22"` |
||||||
|
} |
||||||
|
|
||||||
|
type agentUnlockMsg struct { |
||||||
|
Passphrase []byte `sshtype:"23"` |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) processRequest(data []byte) (interface{}, error) { |
||||||
|
switch data[0] { |
||||||
|
case agentRequestV1Identities: |
||||||
|
return &agentV1IdentityMsg{0}, nil |
||||||
|
case agentRemoveIdentity: |
||||||
|
var req agentRemoveIdentityMsg |
||||||
|
if err := ssh.Unmarshal(data, &req); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var wk wireKey |
||||||
|
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) |
||||||
|
|
||||||
|
case agentRemoveAllIdentities: |
||||||
|
return nil, s.agent.RemoveAll() |
||||||
|
|
||||||
|
case agentLock: |
||||||
|
var req agentLockMsg |
||||||
|
if err := ssh.Unmarshal(data, &req); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return nil, s.agent.Lock(req.Passphrase) |
||||||
|
|
||||||
|
case agentUnlock: |
||||||
|
var req agentLockMsg |
||||||
|
if err := ssh.Unmarshal(data, &req); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return nil, s.agent.Unlock(req.Passphrase) |
||||||
|
|
||||||
|
case agentSignRequest: |
||||||
|
var req signRequestAgentMsg |
||||||
|
if err := ssh.Unmarshal(data, &req); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var wk wireKey |
||||||
|
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
k := &Key{ |
||||||
|
Format: wk.Format, |
||||||
|
Blob: req.KeyBlob, |
||||||
|
} |
||||||
|
|
||||||
|
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
|
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil |
||||||
|
case agentRequestIdentities: |
||||||
|
keys, err := s.agent.List() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
rep := identitiesAnswerAgentMsg{ |
||||||
|
NumKeys: uint32(len(keys)), |
||||||
|
} |
||||||
|
for _, k := range keys { |
||||||
|
rep.Keys = append(rep.Keys, marshalKey(k)...) |
||||||
|
} |
||||||
|
return rep, nil |
||||||
|
case agentAddIdentity: |
||||||
|
return nil, s.insertIdentity(data) |
||||||
|
} |
||||||
|
|
||||||
|
return nil, fmt.Errorf("unknown opcode %d", data[0]) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) insertIdentity(req []byte) error { |
||||||
|
var record struct { |
||||||
|
Type string `sshtype:"17"` |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
if err := ssh.Unmarshal(req, &record); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
switch record.Type { |
||||||
|
case ssh.KeyAlgoRSA: |
||||||
|
var k rsaKeyMsg |
||||||
|
if err := ssh.Unmarshal(req, &k); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
priv := rsa.PrivateKey{ |
||||||
|
PublicKey: rsa.PublicKey{ |
||||||
|
E: int(k.E.Int64()), |
||||||
|
N: k.N, |
||||||
|
}, |
||||||
|
D: k.D, |
||||||
|
Primes: []*big.Int{k.P, k.Q}, |
||||||
|
} |
||||||
|
priv.Precompute() |
||||||
|
|
||||||
|
return s.agent.Add(AddedKey{PrivateKey: &priv, Comment: k.Comments}) |
||||||
|
} |
||||||
|
return fmt.Errorf("not implemented: %s", record.Type) |
||||||
|
} |
||||||
|
|
||||||
|
// ServeAgent serves the agent protocol on the given connection. It
|
||||||
|
// returns when an I/O error occurs.
|
||||||
|
func ServeAgent(agent Agent, c io.ReadWriter) error { |
||||||
|
s := &server{agent} |
||||||
|
|
||||||
|
var length [4]byte |
||||||
|
for { |
||||||
|
if _, err := io.ReadFull(c, length[:]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
l := binary.BigEndian.Uint32(length[:]) |
||||||
|
if l > maxAgentResponseBytes { |
||||||
|
// We also cap requests.
|
||||||
|
return fmt.Errorf("agent: request too large: %d", l) |
||||||
|
} |
||||||
|
|
||||||
|
req := make([]byte, l) |
||||||
|
if _, err := io.ReadFull(c, req); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
repData := s.processRequestBytes(req) |
||||||
|
if len(repData) > maxAgentResponseBytes { |
||||||
|
return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) |
||||||
|
} |
||||||
|
|
||||||
|
binary.BigEndian.PutUint32(length[:], uint32(len(repData))) |
||||||
|
if _, err := c.Write(length[:]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, err := c.Write(repData); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,77 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
func TestServer(t *testing.T) { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
client := NewClient(c1) |
||||||
|
|
||||||
|
go ServeAgent(NewKeyring(), c2) |
||||||
|
|
||||||
|
testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func TestLockServer(t *testing.T) { |
||||||
|
testLockAgent(NewKeyring(), t) |
||||||
|
} |
||||||
|
|
||||||
|
func TestSetupForwardAgent(t *testing.T) { |
||||||
|
a, b, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
|
||||||
|
_, socket, cleanup := startAgent(t) |
||||||
|
defer cleanup() |
||||||
|
|
||||||
|
serverConf := ssh.ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
serverConf.AddHostKey(testSigners["rsa"]) |
||||||
|
incoming := make(chan *ssh.ServerConn, 1) |
||||||
|
go func() { |
||||||
|
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Server: %v", err) |
||||||
|
} |
||||||
|
incoming <- conn |
||||||
|
}() |
||||||
|
|
||||||
|
conf := ssh.ClientConfig{} |
||||||
|
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewClientConn: %v", err) |
||||||
|
} |
||||||
|
client := ssh.NewClient(conn, chans, reqs) |
||||||
|
|
||||||
|
if err := ForwardToRemote(client, socket); err != nil { |
||||||
|
t.Fatalf("SetupForwardAgent: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
server := <-incoming |
||||||
|
ch, reqs, err := server.OpenChannel(channelType, nil) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("OpenChannel(%q): %v", channelType, err) |
||||||
|
} |
||||||
|
go ssh.DiscardRequests(reqs) |
||||||
|
|
||||||
|
agentClient := NewClient(ch) |
||||||
|
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) |
||||||
|
conn.Close() |
||||||
|
} |
@ -0,0 +1,64 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
||||||
|
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
||||||
|
// instances.
|
||||||
|
|
||||||
|
package agent |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh/testdata" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
testPrivateKeys map[string]interface{} |
||||||
|
testSigners map[string]ssh.Signer |
||||||
|
testPublicKeys map[string]ssh.PublicKey |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
var err error |
||||||
|
|
||||||
|
n := len(testdata.PEMBytes) |
||||||
|
testPrivateKeys = make(map[string]interface{}, n) |
||||||
|
testSigners = make(map[string]ssh.Signer, n) |
||||||
|
testPublicKeys = make(map[string]ssh.PublicKey, n) |
||||||
|
for t, k := range testdata.PEMBytes { |
||||||
|
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testPublicKeys[t] = testSigners[t].PublicKey() |
||||||
|
} |
||||||
|
|
||||||
|
// Create a cert and sign it for use in tests.
|
||||||
|
testCert := &ssh.Certificate{ |
||||||
|
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
||||||
|
ValidAfter: 0, // unix epoch
|
||||||
|
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
||||||
|
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
Key: testPublicKeys["ecdsa"], |
||||||
|
SignatureKey: testPublicKeys["rsa"], |
||||||
|
Permissions: ssh.Permissions{ |
||||||
|
CriticalOptions: map[string]string{}, |
||||||
|
Extensions: map[string]string{}, |
||||||
|
}, |
||||||
|
} |
||||||
|
testCert.SignCert(rand.Reader, testSigners["rsa"]) |
||||||
|
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] |
||||||
|
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,122 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
type server struct { |
||||||
|
*ServerConn |
||||||
|
chans <-chan NewChannel |
||||||
|
} |
||||||
|
|
||||||
|
func newServer(c net.Conn, conf *ServerConfig) (*server, error) { |
||||||
|
sconn, chans, reqs, err := NewServerConn(c, conf) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
go DiscardRequests(reqs) |
||||||
|
return &server{sconn, chans}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Accept() (NewChannel, error) { |
||||||
|
n, ok := <-s.chans |
||||||
|
if !ok { |
||||||
|
return nil, io.EOF |
||||||
|
} |
||||||
|
return n, nil |
||||||
|
} |
||||||
|
|
||||||
|
func sshPipe() (Conn, *server, error) { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
clientConf := ClientConfig{ |
||||||
|
User: "user", |
||||||
|
} |
||||||
|
serverConf := ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||||
|
done := make(chan *server, 1) |
||||||
|
go func() { |
||||||
|
server, err := newServer(c2, &serverConf) |
||||||
|
if err != nil { |
||||||
|
done <- nil |
||||||
|
} |
||||||
|
done <- server |
||||||
|
}() |
||||||
|
|
||||||
|
client, _, reqs, err := NewClientConn(c1, "", &clientConf) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
server := <-done |
||||||
|
if server == nil { |
||||||
|
return nil, nil, errors.New("server handshake failed.") |
||||||
|
} |
||||||
|
go DiscardRequests(reqs) |
||||||
|
|
||||||
|
return client, server, nil |
||||||
|
} |
||||||
|
|
||||||
|
func BenchmarkEndToEnd(b *testing.B) { |
||||||
|
b.StopTimer() |
||||||
|
|
||||||
|
client, server, err := sshPipe() |
||||||
|
if err != nil { |
||||||
|
b.Fatalf("sshPipe: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
defer client.Close() |
||||||
|
defer server.Close() |
||||||
|
|
||||||
|
size := (1 << 20) |
||||||
|
input := make([]byte, size) |
||||||
|
output := make([]byte, size) |
||||||
|
b.SetBytes(int64(size)) |
||||||
|
done := make(chan int, 1) |
||||||
|
|
||||||
|
go func() { |
||||||
|
newCh, err := server.Accept() |
||||||
|
if err != nil { |
||||||
|
b.Fatalf("Client: %v", err) |
||||||
|
} |
||||||
|
ch, incoming, err := newCh.Accept() |
||||||
|
go DiscardRequests(incoming) |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
if _, err := io.ReadFull(ch, output); err != nil { |
||||||
|
b.Fatalf("ReadFull: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
ch.Close() |
||||||
|
done <- 1 |
||||||
|
}() |
||||||
|
|
||||||
|
ch, in, err := client.OpenChannel("speed", nil) |
||||||
|
if err != nil { |
||||||
|
b.Fatalf("OpenChannel: %v", err) |
||||||
|
} |
||||||
|
go DiscardRequests(in) |
||||||
|
|
||||||
|
b.ResetTimer() |
||||||
|
b.StartTimer() |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
if _, err := ch.Write(input); err != nil { |
||||||
|
b.Fatalf("WriteFull: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
ch.Close() |
||||||
|
b.StopTimer() |
||||||
|
|
||||||
|
<-done |
||||||
|
} |
@ -0,0 +1,98 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
// buffer provides a linked list buffer for data exchange
|
||||||
|
// between producer and consumer. Theoretically the buffer is
|
||||||
|
// of unlimited capacity as it does no allocation of its own.
|
||||||
|
type buffer struct { |
||||||
|
// protects concurrent access to head, tail and closed
|
||||||
|
*sync.Cond |
||||||
|
|
||||||
|
head *element // the buffer that will be read first
|
||||||
|
tail *element // the buffer that will be read last
|
||||||
|
|
||||||
|
closed bool |
||||||
|
} |
||||||
|
|
||||||
|
// An element represents a single link in a linked list.
|
||||||
|
type element struct { |
||||||
|
buf []byte |
||||||
|
next *element |
||||||
|
} |
||||||
|
|
||||||
|
// newBuffer returns an empty buffer that is not closed.
|
||||||
|
func newBuffer() *buffer { |
||||||
|
e := new(element) |
||||||
|
b := &buffer{ |
||||||
|
Cond: newCond(), |
||||||
|
head: e, |
||||||
|
tail: e, |
||||||
|
} |
||||||
|
return b |
||||||
|
} |
||||||
|
|
||||||
|
// write makes buf available for Read to receive.
|
||||||
|
// buf must not be modified after the call to write.
|
||||||
|
func (b *buffer) write(buf []byte) { |
||||||
|
b.Cond.L.Lock() |
||||||
|
e := &element{buf: buf} |
||||||
|
b.tail.next = e |
||||||
|
b.tail = e |
||||||
|
b.Cond.Signal() |
||||||
|
b.Cond.L.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// eof closes the buffer. Reads from the buffer once all
|
||||||
|
// the data has been consumed will receive os.EOF.
|
||||||
|
func (b *buffer) eof() error { |
||||||
|
b.Cond.L.Lock() |
||||||
|
b.closed = true |
||||||
|
b.Cond.Signal() |
||||||
|
b.Cond.L.Unlock() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Read reads data from the internal buffer in buf. Reads will block
|
||||||
|
// if no data is available, or until the buffer is closed.
|
||||||
|
func (b *buffer) Read(buf []byte) (n int, err error) { |
||||||
|
b.Cond.L.Lock() |
||||||
|
defer b.Cond.L.Unlock() |
||||||
|
|
||||||
|
for len(buf) > 0 { |
||||||
|
// if there is data in b.head, copy it
|
||||||
|
if len(b.head.buf) > 0 { |
||||||
|
r := copy(buf, b.head.buf) |
||||||
|
buf, b.head.buf = buf[r:], b.head.buf[r:] |
||||||
|
n += r |
||||||
|
continue |
||||||
|
} |
||||||
|
// if there is a next buffer, make it the head
|
||||||
|
if len(b.head.buf) == 0 && b.head != b.tail { |
||||||
|
b.head = b.head.next |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
// if at least one byte has been copied, return
|
||||||
|
if n > 0 { |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
// if nothing was read, and there is nothing outstanding
|
||||||
|
// check to see if the buffer is closed.
|
||||||
|
if b.closed { |
||||||
|
err = io.EOF |
||||||
|
break |
||||||
|
} |
||||||
|
// out of buffers, wait for producer
|
||||||
|
b.Cond.Wait() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,87 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") |
||||||
|
|
||||||
|
func TestBufferReadwrite(t *testing.T) { |
||||||
|
b := newBuffer() |
||||||
|
b.write(alphabet[:10]) |
||||||
|
r, _ := b.Read(make([]byte, 10)) |
||||||
|
if r != 10 { |
||||||
|
t.Fatalf("Expected written == read == 10, written: 10, read %d", r) |
||||||
|
} |
||||||
|
|
||||||
|
b = newBuffer() |
||||||
|
b.write(alphabet[:5]) |
||||||
|
r, _ = b.Read(make([]byte, 10)) |
||||||
|
if r != 5 { |
||||||
|
t.Fatalf("Expected written == read == 5, written: 5, read %d", r) |
||||||
|
} |
||||||
|
|
||||||
|
b = newBuffer() |
||||||
|
b.write(alphabet[:10]) |
||||||
|
r, _ = b.Read(make([]byte, 5)) |
||||||
|
if r != 5 { |
||||||
|
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) |
||||||
|
} |
||||||
|
|
||||||
|
b = newBuffer() |
||||||
|
b.write(alphabet[:5]) |
||||||
|
b.write(alphabet[5:15]) |
||||||
|
r, _ = b.Read(make([]byte, 10)) |
||||||
|
r2, _ := b.Read(make([]byte, 10)) |
||||||
|
if r != 10 || r2 != 5 || 15 != r+r2 { |
||||||
|
t.Fatal("Expected written == read == 15") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestBufferClose(t *testing.T) { |
||||||
|
b := newBuffer() |
||||||
|
b.write(alphabet[:10]) |
||||||
|
b.eof() |
||||||
|
_, err := b.Read(make([]byte, 5)) |
||||||
|
if err != nil { |
||||||
|
t.Fatal("expected read of 5 to not return EOF") |
||||||
|
} |
||||||
|
b = newBuffer() |
||||||
|
b.write(alphabet[:10]) |
||||||
|
b.eof() |
||||||
|
r, err := b.Read(make([]byte, 5)) |
||||||
|
r2, err2 := b.Read(make([]byte, 10)) |
||||||
|
if r != 5 || r2 != 5 || err != nil || err2 != nil { |
||||||
|
t.Fatal("expected reads of 5 and 5") |
||||||
|
} |
||||||
|
|
||||||
|
b = newBuffer() |
||||||
|
b.write(alphabet[:10]) |
||||||
|
b.eof() |
||||||
|
r, err = b.Read(make([]byte, 5)) |
||||||
|
r2, err2 = b.Read(make([]byte, 10)) |
||||||
|
r3, err3 := b.Read(make([]byte, 10)) |
||||||
|
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { |
||||||
|
t.Fatal("expected reads of 5 and 5 and 0, with EOF") |
||||||
|
} |
||||||
|
|
||||||
|
b = newBuffer() |
||||||
|
b.write(make([]byte, 5)) |
||||||
|
b.write(make([]byte, 10)) |
||||||
|
b.eof() |
||||||
|
r, err = b.Read(make([]byte, 9)) |
||||||
|
r2, err2 = b.Read(make([]byte, 3)) |
||||||
|
r3, err3 = b.Read(make([]byte, 3)) |
||||||
|
r4, err4 := b.Read(make([]byte, 10)) |
||||||
|
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { |
||||||
|
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) |
||||||
|
} |
||||||
|
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { |
||||||
|
t.Fatal("Expected written == read == 15", r, r2, r3, r4) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,501 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"sort" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
// These constants from [PROTOCOL.certkeys] represent the algorithm names
|
||||||
|
// for certificate types supported by this package.
|
||||||
|
const ( |
||||||
|
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" |
||||||
|
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" |
||||||
|
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" |
||||||
|
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" |
||||||
|
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" |
||||||
|
) |
||||||
|
|
||||||
|
// Certificate types distinguish between host and user
|
||||||
|
// certificates. The values can be set in the CertType field of
|
||||||
|
// Certificate.
|
||||||
|
const ( |
||||||
|
UserCert = 1 |
||||||
|
HostCert = 2 |
||||||
|
) |
||||||
|
|
||||||
|
// Signature represents a cryptographic signature.
|
||||||
|
type Signature struct { |
||||||
|
Format string |
||||||
|
Blob []byte |
||||||
|
} |
||||||
|
|
||||||
|
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
||||||
|
// a certificate does not expire.
|
||||||
|
const CertTimeInfinity = 1<<64 - 1 |
||||||
|
|
||||||
|
// An Certificate represents an OpenSSH certificate as defined in
|
||||||
|
// [PROTOCOL.certkeys]?rev=1.8.
|
||||||
|
type Certificate struct { |
||||||
|
Nonce []byte |
||||||
|
Key PublicKey |
||||||
|
Serial uint64 |
||||||
|
CertType uint32 |
||||||
|
KeyId string |
||||||
|
ValidPrincipals []string |
||||||
|
ValidAfter uint64 |
||||||
|
ValidBefore uint64 |
||||||
|
Permissions |
||||||
|
Reserved []byte |
||||||
|
SignatureKey PublicKey |
||||||
|
Signature *Signature |
||||||
|
} |
||||||
|
|
||||||
|
// genericCertData holds the key-independent part of the certificate data.
|
||||||
|
// Overall, certificates contain an nonce, public key fields and
|
||||||
|
// key-independent fields.
|
||||||
|
type genericCertData struct { |
||||||
|
Serial uint64 |
||||||
|
CertType uint32 |
||||||
|
KeyId string |
||||||
|
ValidPrincipals []byte |
||||||
|
ValidAfter uint64 |
||||||
|
ValidBefore uint64 |
||||||
|
CriticalOptions []byte |
||||||
|
Extensions []byte |
||||||
|
Reserved []byte |
||||||
|
SignatureKey []byte |
||||||
|
Signature []byte |
||||||
|
} |
||||||
|
|
||||||
|
func marshalStringList(namelist []string) []byte { |
||||||
|
var to []byte |
||||||
|
for _, name := range namelist { |
||||||
|
s := struct{ N string }{name} |
||||||
|
to = append(to, Marshal(&s)...) |
||||||
|
} |
||||||
|
return to |
||||||
|
} |
||||||
|
|
||||||
|
type optionsTuple struct { |
||||||
|
Key string |
||||||
|
Value []byte |
||||||
|
} |
||||||
|
|
||||||
|
type optionsTupleValue struct { |
||||||
|
Value string |
||||||
|
} |
||||||
|
|
||||||
|
// serialize a map of critical options or extensions
|
||||||
|
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||||
|
// we need two length prefixes for a non-empty string value
|
||||||
|
func marshalTuples(tups map[string]string) []byte { |
||||||
|
keys := make([]string, 0, len(tups)) |
||||||
|
for key := range tups { |
||||||
|
keys = append(keys, key) |
||||||
|
} |
||||||
|
sort.Strings(keys) |
||||||
|
|
||||||
|
var ret []byte |
||||||
|
for _, key := range keys { |
||||||
|
s := optionsTuple{Key: key} |
||||||
|
if value := tups[key]; len(value) > 0 { |
||||||
|
s.Value = Marshal(&optionsTupleValue{value}) |
||||||
|
} |
||||||
|
ret = append(ret, Marshal(&s)...) |
||||||
|
} |
||||||
|
return ret |
||||||
|
} |
||||||
|
|
||||||
|
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||||
|
// we need two length prefixes for a non-empty option value
|
||||||
|
func parseTuples(in []byte) (map[string]string, error) { |
||||||
|
tups := map[string]string{} |
||||||
|
var lastKey string |
||||||
|
var haveLastKey bool |
||||||
|
|
||||||
|
for len(in) > 0 { |
||||||
|
var key, val, extra []byte |
||||||
|
var ok bool |
||||||
|
|
||||||
|
if key, in, ok = parseString(in); !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
keyStr := string(key) |
||||||
|
// according to [PROTOCOL.certkeys], the names must be in
|
||||||
|
// lexical order.
|
||||||
|
if haveLastKey && keyStr <= lastKey { |
||||||
|
return nil, fmt.Errorf("ssh: certificate options are not in lexical order") |
||||||
|
} |
||||||
|
lastKey, haveLastKey = keyStr, true |
||||||
|
// the next field is a data field, which if non-empty has a string embedded
|
||||||
|
if val, in, ok = parseString(in); !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
if len(val) > 0 { |
||||||
|
val, extra, ok = parseString(val) |
||||||
|
if !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
if len(extra) > 0 { |
||||||
|
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") |
||||||
|
} |
||||||
|
tups[keyStr] = string(val) |
||||||
|
} else { |
||||||
|
tups[keyStr] = "" |
||||||
|
} |
||||||
|
} |
||||||
|
return tups, nil |
||||||
|
} |
||||||
|
|
||||||
|
func parseCert(in []byte, privAlgo string) (*Certificate, error) { |
||||||
|
nonce, rest, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
|
||||||
|
key, rest, err := parsePubKey(rest, privAlgo) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var g genericCertData |
||||||
|
if err := Unmarshal(rest, &g); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c := &Certificate{ |
||||||
|
Nonce: nonce, |
||||||
|
Key: key, |
||||||
|
Serial: g.Serial, |
||||||
|
CertType: g.CertType, |
||||||
|
KeyId: g.KeyId, |
||||||
|
ValidAfter: g.ValidAfter, |
||||||
|
ValidBefore: g.ValidBefore, |
||||||
|
} |
||||||
|
|
||||||
|
for principals := g.ValidPrincipals; len(principals) > 0; { |
||||||
|
principal, rest, ok := parseString(principals) |
||||||
|
if !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) |
||||||
|
principals = rest |
||||||
|
} |
||||||
|
|
||||||
|
c.CriticalOptions, err = parseTuples(g.CriticalOptions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
c.Extensions, err = parseTuples(g.Extensions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
c.Reserved = g.Reserved |
||||||
|
k, err := ParsePublicKey(g.SignatureKey) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c.SignatureKey = k |
||||||
|
c.Signature, rest, ok = parseSignatureBody(g.Signature) |
||||||
|
if !ok || len(rest) > 0 { |
||||||
|
return nil, errors.New("ssh: signature parse error") |
||||||
|
} |
||||||
|
|
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
type openSSHCertSigner struct { |
||||||
|
pub *Certificate |
||||||
|
signer Signer |
||||||
|
} |
||||||
|
|
||||||
|
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
||||||
|
// private key is held by signer. It returns an error if the public key in cert
|
||||||
|
// doesn't match the key used by signer.
|
||||||
|
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { |
||||||
|
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||||
|
return nil, errors.New("ssh: signer and cert have different public key") |
||||||
|
} |
||||||
|
|
||||||
|
return &openSSHCertSigner{cert, signer}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||||
|
return s.signer.Sign(rand, data) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *openSSHCertSigner) PublicKey() PublicKey { |
||||||
|
return s.pub |
||||||
|
} |
||||||
|
|
||||||
|
const sourceAddressCriticalOption = "source-address" |
||||||
|
|
||||||
|
// CertChecker does the work of verifying a certificate. Its methods
|
||||||
|
// can be plugged into ClientConfig.HostKeyCallback and
|
||||||
|
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
||||||
|
// minimally, the IsAuthority callback should be set.
|
||||||
|
type CertChecker struct { |
||||||
|
// SupportedCriticalOptions lists the CriticalOptions that the
|
||||||
|
// server application layer understands. These are only used
|
||||||
|
// for user certificates.
|
||||||
|
SupportedCriticalOptions []string |
||||||
|
|
||||||
|
// IsAuthority should return true if the key is recognized as
|
||||||
|
// an authority. This allows for certificates to be signed by other
|
||||||
|
// certificates.
|
||||||
|
IsAuthority func(auth PublicKey) bool |
||||||
|
|
||||||
|
// Clock is used for verifying time stamps. If nil, time.Now
|
||||||
|
// is used.
|
||||||
|
Clock func() time.Time |
||||||
|
|
||||||
|
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
||||||
|
// public key that is not a certificate. It must implement validation
|
||||||
|
// of user keys or else, if nil, all such keys are rejected.
|
||||||
|
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||||
|
|
||||||
|
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
||||||
|
// public key that is not a certificate. It must implement host key
|
||||||
|
// validation or else, if nil, all such keys are rejected.
|
||||||
|
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error |
||||||
|
|
||||||
|
// IsRevoked is called for each certificate so that revocation checking
|
||||||
|
// can be implemented. It should return true if the given certificate
|
||||||
|
// is revoked and false otherwise. If nil, no certificates are
|
||||||
|
// considered to have been revoked.
|
||||||
|
IsRevoked func(cert *Certificate) bool |
||||||
|
} |
||||||
|
|
||||||
|
// CheckHostKey checks a host key certificate. This method can be
|
||||||
|
// plugged into ClientConfig.HostKeyCallback.
|
||||||
|
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { |
||||||
|
cert, ok := key.(*Certificate) |
||||||
|
if !ok { |
||||||
|
if c.HostKeyFallback != nil { |
||||||
|
return c.HostKeyFallback(addr, remote, key) |
||||||
|
} |
||||||
|
return errors.New("ssh: non-certificate host key") |
||||||
|
} |
||||||
|
if cert.CertType != HostCert { |
||||||
|
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) |
||||||
|
} |
||||||
|
|
||||||
|
return c.CheckCert(addr, cert) |
||||||
|
} |
||||||
|
|
||||||
|
// Authenticate checks a user certificate. Authenticate can be used as
|
||||||
|
// a value for ServerConfig.PublicKeyCallback.
|
||||||
|
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { |
||||||
|
cert, ok := pubKey.(*Certificate) |
||||||
|
if !ok { |
||||||
|
if c.UserKeyFallback != nil { |
||||||
|
return c.UserKeyFallback(conn, pubKey) |
||||||
|
} |
||||||
|
return nil, errors.New("ssh: normal key pairs not accepted") |
||||||
|
} |
||||||
|
|
||||||
|
if cert.CertType != UserCert { |
||||||
|
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) |
||||||
|
} |
||||||
|
|
||||||
|
if err := c.CheckCert(conn.User(), cert); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &cert.Permissions, nil |
||||||
|
} |
||||||
|
|
||||||
|
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
||||||
|
// the signature of the certificate.
|
||||||
|
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { |
||||||
|
if c.IsRevoked != nil && c.IsRevoked(cert) { |
||||||
|
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) |
||||||
|
} |
||||||
|
|
||||||
|
for opt, _ := range cert.CriticalOptions { |
||||||
|
// sourceAddressCriticalOption will be enforced by
|
||||||
|
// serverAuthenticate
|
||||||
|
if opt == sourceAddressCriticalOption { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
found := false |
||||||
|
for _, supp := range c.SupportedCriticalOptions { |
||||||
|
if supp == opt { |
||||||
|
found = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
if !found { |
||||||
|
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if len(cert.ValidPrincipals) > 0 { |
||||||
|
// By default, certs are valid for all users/hosts.
|
||||||
|
found := false |
||||||
|
for _, p := range cert.ValidPrincipals { |
||||||
|
if p == principal { |
||||||
|
found = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
if !found { |
||||||
|
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if !c.IsAuthority(cert.SignatureKey) { |
||||||
|
return fmt.Errorf("ssh: certificate signed by unrecognized authority") |
||||||
|
} |
||||||
|
|
||||||
|
clock := c.Clock |
||||||
|
if clock == nil { |
||||||
|
clock = time.Now |
||||||
|
} |
||||||
|
|
||||||
|
unixNow := clock().Unix() |
||||||
|
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { |
||||||
|
return fmt.Errorf("ssh: cert is not yet valid") |
||||||
|
} |
||||||
|
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { |
||||||
|
return fmt.Errorf("ssh: cert has expired") |
||||||
|
} |
||||||
|
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { |
||||||
|
return fmt.Errorf("ssh: certificate signature does not verify") |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// SignCert sets c.SignatureKey to the authority's public key and stores a
|
||||||
|
// Signature, by authority, in the certificate.
|
||||||
|
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { |
||||||
|
c.Nonce = make([]byte, 32) |
||||||
|
if _, err := io.ReadFull(rand, c.Nonce); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
c.SignatureKey = authority.PublicKey() |
||||||
|
|
||||||
|
sig, err := authority.Sign(rand, c.bytesForSigning()) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
c.Signature = sig |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
var certAlgoNames = map[string]string{ |
||||||
|
KeyAlgoRSA: CertAlgoRSAv01, |
||||||
|
KeyAlgoDSA: CertAlgoDSAv01, |
||||||
|
KeyAlgoECDSA256: CertAlgoECDSA256v01, |
||||||
|
KeyAlgoECDSA384: CertAlgoECDSA384v01, |
||||||
|
KeyAlgoECDSA521: CertAlgoECDSA521v01, |
||||||
|
} |
||||||
|
|
||||||
|
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
|
||||||
|
// Panics if a non-certificate algorithm is passed.
|
||||||
|
func certToPrivAlgo(algo string) string { |
||||||
|
for privAlgo, pubAlgo := range certAlgoNames { |
||||||
|
if pubAlgo == algo { |
||||||
|
return privAlgo |
||||||
|
} |
||||||
|
} |
||||||
|
panic("unknown cert algorithm") |
||||||
|
} |
||||||
|
|
||||||
|
func (cert *Certificate) bytesForSigning() []byte { |
||||||
|
c2 := *cert |
||||||
|
c2.Signature = nil |
||||||
|
out := c2.Marshal() |
||||||
|
// Drop trailing signature length.
|
||||||
|
return out[:len(out)-4] |
||||||
|
} |
||||||
|
|
||||||
|
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
||||||
|
// PublicKey interface.
|
||||||
|
func (c *Certificate) Marshal() []byte { |
||||||
|
generic := genericCertData{ |
||||||
|
Serial: c.Serial, |
||||||
|
CertType: c.CertType, |
||||||
|
KeyId: c.KeyId, |
||||||
|
ValidPrincipals: marshalStringList(c.ValidPrincipals), |
||||||
|
ValidAfter: uint64(c.ValidAfter), |
||||||
|
ValidBefore: uint64(c.ValidBefore), |
||||||
|
CriticalOptions: marshalTuples(c.CriticalOptions), |
||||||
|
Extensions: marshalTuples(c.Extensions), |
||||||
|
Reserved: c.Reserved, |
||||||
|
SignatureKey: c.SignatureKey.Marshal(), |
||||||
|
} |
||||||
|
if c.Signature != nil { |
||||||
|
generic.Signature = Marshal(c.Signature) |
||||||
|
} |
||||||
|
genericBytes := Marshal(&generic) |
||||||
|
keyBytes := c.Key.Marshal() |
||||||
|
_, keyBytes, _ = parseString(keyBytes) |
||||||
|
prefix := Marshal(&struct { |
||||||
|
Name string |
||||||
|
Nonce []byte |
||||||
|
Key []byte `ssh:"rest"` |
||||||
|
}{c.Type(), c.Nonce, keyBytes}) |
||||||
|
|
||||||
|
result := make([]byte, 0, len(prefix)+len(genericBytes)) |
||||||
|
result = append(result, prefix...) |
||||||
|
result = append(result, genericBytes...) |
||||||
|
return result |
||||||
|
} |
||||||
|
|
||||||
|
// Type returns the key name. It is part of the PublicKey interface.
|
||||||
|
func (c *Certificate) Type() string { |
||||||
|
algo, ok := certAlgoNames[c.Key.Type()] |
||||||
|
if !ok { |
||||||
|
panic("unknown cert key type") |
||||||
|
} |
||||||
|
return algo |
||||||
|
} |
||||||
|
|
||||||
|
// Verify verifies a signature against the certificate's public
|
||||||
|
// key. It is part of the PublicKey interface.
|
||||||
|
func (c *Certificate) Verify(data []byte, sig *Signature) error { |
||||||
|
return c.Key.Verify(data, sig) |
||||||
|
} |
||||||
|
|
||||||
|
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { |
||||||
|
format, in, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
out = &Signature{ |
||||||
|
Format: string(format), |
||||||
|
} |
||||||
|
|
||||||
|
if out.Blob, in, ok = parseString(in); !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
return out, in, ok |
||||||
|
} |
||||||
|
|
||||||
|
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { |
||||||
|
sigBytes, rest, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
out, trailing, ok := parseSignatureBody(sigBytes) |
||||||
|
if !ok || len(trailing) > 0 { |
||||||
|
return nil, nil, false |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,216 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"reflect" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
// Cert generated by ssh-keygen 6.0p1 Debian-4.
|
||||||
|
// % ssh-keygen -s ca-key -I test user-key
|
||||||
|
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` |
||||||
|
|
||||||
|
func TestParseCert(t *testing.T) { |
||||||
|
authKeyBytes := []byte(exampleSSHCert) |
||||||
|
|
||||||
|
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||||
|
} |
||||||
|
if len(rest) > 0 { |
||||||
|
t.Errorf("rest: got %q, want empty", rest) |
||||||
|
} |
||||||
|
|
||||||
|
if _, ok := key.(*Certificate); !ok { |
||||||
|
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||||
|
} |
||||||
|
|
||||||
|
marshaled := MarshalAuthorizedKey(key) |
||||||
|
// Before comparison, remove the trailing newline that
|
||||||
|
// MarshalAuthorizedKey adds.
|
||||||
|
marshaled = marshaled[:len(marshaled)-1] |
||||||
|
if !bytes.Equal(authKeyBytes, marshaled) { |
||||||
|
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
|
||||||
|
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
|
||||||
|
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
|
||||||
|
// Critical Options:
|
||||||
|
// force-command /bin/sleep
|
||||||
|
// source-address 192.168.1.0/24
|
||||||
|
// Extensions:
|
||||||
|
// permit-X11-forwarding
|
||||||
|
// permit-agent-forwarding
|
||||||
|
// permit-port-forwarding
|
||||||
|
// permit-pty
|
||||||
|
// permit-user-rc
|
||||||
|
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` |
||||||
|
|
||||||
|
func TestParseCertWithOptions(t *testing.T) { |
||||||
|
opts := map[string]string{ |
||||||
|
"source-address": "192.168.1.0/24", |
||||||
|
"force-command": "/bin/sleep", |
||||||
|
} |
||||||
|
exts := map[string]string{ |
||||||
|
"permit-X11-forwarding": "", |
||||||
|
"permit-agent-forwarding": "", |
||||||
|
"permit-port-forwarding": "", |
||||||
|
"permit-pty": "", |
||||||
|
"permit-user-rc": "", |
||||||
|
} |
||||||
|
authKeyBytes := []byte(exampleSSHCertWithOptions) |
||||||
|
|
||||||
|
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||||
|
} |
||||||
|
if len(rest) > 0 { |
||||||
|
t.Errorf("rest: got %q, want empty", rest) |
||||||
|
} |
||||||
|
cert, ok := key.(*Certificate) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||||
|
} |
||||||
|
if !reflect.DeepEqual(cert.CriticalOptions, opts) { |
||||||
|
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) |
||||||
|
} |
||||||
|
if !reflect.DeepEqual(cert.Extensions, exts) { |
||||||
|
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) |
||||||
|
} |
||||||
|
marshaled := MarshalAuthorizedKey(key) |
||||||
|
// Before comparison, remove the trailing newline that
|
||||||
|
// MarshalAuthorizedKey adds.
|
||||||
|
marshaled = marshaled[:len(marshaled)-1] |
||||||
|
if !bytes.Equal(authKeyBytes, marshaled) { |
||||||
|
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateCert(t *testing.T) { |
||||||
|
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||||
|
} |
||||||
|
validCert, ok := key.(*Certificate) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||||
|
} |
||||||
|
checker := CertChecker{} |
||||||
|
checker.IsAuthority = func(k PublicKey) bool { |
||||||
|
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) |
||||||
|
} |
||||||
|
|
||||||
|
if err := checker.CheckCert("user", validCert); err != nil { |
||||||
|
t.Errorf("Unable to validate certificate: %v", err) |
||||||
|
} |
||||||
|
invalidCert := &Certificate{ |
||||||
|
Key: testPublicKeys["rsa"], |
||||||
|
SignatureKey: testPublicKeys["ecdsa"], |
||||||
|
ValidBefore: CertTimeInfinity, |
||||||
|
Signature: &Signature{}, |
||||||
|
} |
||||||
|
if err := checker.CheckCert("user", invalidCert); err == nil { |
||||||
|
t.Error("Invalid cert signature passed validation") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateCertTime(t *testing.T) { |
||||||
|
cert := Certificate{ |
||||||
|
ValidPrincipals: []string{"user"}, |
||||||
|
Key: testPublicKeys["rsa"], |
||||||
|
ValidAfter: 50, |
||||||
|
ValidBefore: 100, |
||||||
|
} |
||||||
|
|
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
|
||||||
|
for ts, ok := range map[int64]bool{ |
||||||
|
25: false, |
||||||
|
50: true, |
||||||
|
99: true, |
||||||
|
100: false, |
||||||
|
125: false, |
||||||
|
} { |
||||||
|
checker := CertChecker{ |
||||||
|
Clock: func() time.Time { return time.Unix(ts, 0) }, |
||||||
|
} |
||||||
|
checker.IsAuthority = func(k PublicKey) bool { |
||||||
|
return bytes.Equal(k.Marshal(), |
||||||
|
testPublicKeys["ecdsa"].Marshal()) |
||||||
|
} |
||||||
|
|
||||||
|
if v := checker.CheckCert("user", &cert); (v == nil) != ok { |
||||||
|
t.Errorf("Authenticate(%d): %v", ts, v) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// TODO(hanwen): tests for
|
||||||
|
//
|
||||||
|
// host keys:
|
||||||
|
// * fallbacks
|
||||||
|
|
||||||
|
func TestHostKeyCert(t *testing.T) { |
||||||
|
cert := &Certificate{ |
||||||
|
ValidPrincipals: []string{"hostname", "hostname.domain"}, |
||||||
|
Key: testPublicKeys["rsa"], |
||||||
|
ValidBefore: CertTimeInfinity, |
||||||
|
CertType: HostCert, |
||||||
|
} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
|
||||||
|
checker := &CertChecker{ |
||||||
|
IsAuthority: func(p PublicKey) bool { |
||||||
|
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("NewCertSigner: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
for _, name := range []string{"hostname", "otherhost"} { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
errc := make(chan error) |
||||||
|
|
||||||
|
go func() { |
||||||
|
conf := ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
conf.AddHostKey(certSigner) |
||||||
|
_, _, _, err := NewServerConn(c1, &conf) |
||||||
|
errc <- err |
||||||
|
}() |
||||||
|
|
||||||
|
config := &ClientConfig{ |
||||||
|
User: "user", |
||||||
|
HostKeyCallback: checker.CheckHostKey, |
||||||
|
} |
||||||
|
_, _, _, err = NewClientConn(c2, name, config) |
||||||
|
|
||||||
|
succeed := name == "hostname" |
||||||
|
if (err == nil) != succeed { |
||||||
|
t.Fatalf("NewClientConn(%q): %v", name, err) |
||||||
|
} |
||||||
|
|
||||||
|
err = <-errc |
||||||
|
if (err == nil) != succeed { |
||||||
|
t.Fatalf("NewServerConn(%q): %v", name, err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,631 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"log" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
minPacketLength = 9 |
||||||
|
// channelMaxPacket contains the maximum number of bytes that will be
|
||||||
|
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||||
|
// the minimum.
|
||||||
|
channelMaxPacket = 1 << 15 |
||||||
|
// We follow OpenSSH here.
|
||||||
|
channelWindowSize = 64 * channelMaxPacket |
||||||
|
) |
||||||
|
|
||||||
|
// NewChannel represents an incoming request to a channel. It must either be
|
||||||
|
// accepted for use by calling Accept, or rejected by calling Reject.
|
||||||
|
type NewChannel interface { |
||||||
|
// Accept accepts the channel creation request. It returns the Channel
|
||||||
|
// and a Go channel containing SSH requests. The Go channel must be
|
||||||
|
// serviced otherwise the Channel will hang.
|
||||||
|
Accept() (Channel, <-chan *Request, error) |
||||||
|
|
||||||
|
// Reject rejects the channel creation request. After calling
|
||||||
|
// this, no other methods on the Channel may be called.
|
||||||
|
Reject(reason RejectionReason, message string) error |
||||||
|
|
||||||
|
// ChannelType returns the type of the channel, as supplied by the
|
||||||
|
// client.
|
||||||
|
ChannelType() string |
||||||
|
|
||||||
|
// ExtraData returns the arbitrary payload for this channel, as supplied
|
||||||
|
// by the client. This data is specific to the channel type.
|
||||||
|
ExtraData() []byte |
||||||
|
} |
||||||
|
|
||||||
|
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
||||||
|
// that is multiplexed over an SSH connection.
|
||||||
|
type Channel interface { |
||||||
|
// Read reads up to len(data) bytes from the channel.
|
||||||
|
Read(data []byte) (int, error) |
||||||
|
|
||||||
|
// Write writes len(data) bytes to the channel.
|
||||||
|
Write(data []byte) (int, error) |
||||||
|
|
||||||
|
// Close signals end of channel use. No data may be sent after this
|
||||||
|
// call.
|
||||||
|
Close() error |
||||||
|
|
||||||
|
// CloseWrite signals the end of sending in-band
|
||||||
|
// data. Requests may still be sent, and the other side may
|
||||||
|
// still send data
|
||||||
|
CloseWrite() error |
||||||
|
|
||||||
|
// SendRequest sends a channel request. If wantReply is true,
|
||||||
|
// it will wait for a reply and return the result as a
|
||||||
|
// boolean, otherwise the return value will be false. Channel
|
||||||
|
// requests are out-of-band messages so they may be sent even
|
||||||
|
// if the data stream is closed or blocked by flow control.
|
||||||
|
SendRequest(name string, wantReply bool, payload []byte) (bool, error) |
||||||
|
|
||||||
|
// Stderr returns an io.ReadWriter that writes to this channel
|
||||||
|
// with the extended data type set to stderr. Stderr may
|
||||||
|
// safely be read and written from a different goroutine than
|
||||||
|
// Read and Write respectively.
|
||||||
|
Stderr() io.ReadWriter |
||||||
|
} |
||||||
|
|
||||||
|
// Request is a request sent outside of the normal stream of
|
||||||
|
// data. Requests can either be specific to an SSH channel, or they
|
||||||
|
// can be global.
|
||||||
|
type Request struct { |
||||||
|
Type string |
||||||
|
WantReply bool |
||||||
|
Payload []byte |
||||||
|
|
||||||
|
ch *channel |
||||||
|
mux *mux |
||||||
|
} |
||||||
|
|
||||||
|
// Reply sends a response to a request. It must be called for all requests
|
||||||
|
// where WantReply is true and is a no-op otherwise. The payload argument is
|
||||||
|
// ignored for replies to channel-specific requests.
|
||||||
|
func (r *Request) Reply(ok bool, payload []byte) error { |
||||||
|
if !r.WantReply { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
if r.ch == nil { |
||||||
|
return r.mux.ackRequest(ok, payload) |
||||||
|
} |
||||||
|
|
||||||
|
return r.ch.ackRequest(ok) |
||||||
|
} |
||||||
|
|
||||||
|
// RejectionReason is an enumeration used when rejecting channel creation
|
||||||
|
// requests. See RFC 4254, section 5.1.
|
||||||
|
type RejectionReason uint32 |
||||||
|
|
||||||
|
const ( |
||||||
|
Prohibited RejectionReason = iota + 1 |
||||||
|
ConnectionFailed |
||||||
|
UnknownChannelType |
||||||
|
ResourceShortage |
||||||
|
) |
||||||
|
|
||||||
|
// String converts the rejection reason to human readable form.
|
||||||
|
func (r RejectionReason) String() string { |
||||||
|
switch r { |
||||||
|
case Prohibited: |
||||||
|
return "administratively prohibited" |
||||||
|
case ConnectionFailed: |
||||||
|
return "connect failed" |
||||||
|
case UnknownChannelType: |
||||||
|
return "unknown channel type" |
||||||
|
case ResourceShortage: |
||||||
|
return "resource shortage" |
||||||
|
} |
||||||
|
return fmt.Sprintf("unknown reason %d", int(r)) |
||||||
|
} |
||||||
|
|
||||||
|
func min(a uint32, b int) uint32 { |
||||||
|
if a < uint32(b) { |
||||||
|
return a |
||||||
|
} |
||||||
|
return uint32(b) |
||||||
|
} |
||||||
|
|
||||||
|
type channelDirection uint8 |
||||||
|
|
||||||
|
const ( |
||||||
|
channelInbound channelDirection = iota |
||||||
|
channelOutbound |
||||||
|
) |
||||||
|
|
||||||
|
// channel is an implementation of the Channel interface that works
|
||||||
|
// with the mux class.
|
||||||
|
type channel struct { |
||||||
|
// R/O after creation
|
||||||
|
chanType string |
||||||
|
extraData []byte |
||||||
|
localId, remoteId uint32 |
||||||
|
|
||||||
|
// maxIncomingPayload and maxRemotePayload are the maximum
|
||||||
|
// payload sizes of normal and extended data packets for
|
||||||
|
// receiving and sending, respectively. The wire packet will
|
||||||
|
// be 9 or 13 bytes larger (excluding encryption overhead).
|
||||||
|
maxIncomingPayload uint32 |
||||||
|
maxRemotePayload uint32 |
||||||
|
|
||||||
|
mux *mux |
||||||
|
|
||||||
|
// decided is set to true if an accept or reject message has been sent
|
||||||
|
// (for outbound channels) or received (for inbound channels).
|
||||||
|
decided bool |
||||||
|
|
||||||
|
// direction contains either channelOutbound, for channels created
|
||||||
|
// locally, or channelInbound, for channels created by the peer.
|
||||||
|
direction channelDirection |
||||||
|
|
||||||
|
// Pending internal channel messages.
|
||||||
|
msg chan interface{} |
||||||
|
|
||||||
|
// Since requests have no ID, there can be only one request
|
||||||
|
// with WantReply=true outstanding. This lock is held by a
|
||||||
|
// goroutine that has such an outgoing request pending.
|
||||||
|
sentRequestMu sync.Mutex |
||||||
|
|
||||||
|
incomingRequests chan *Request |
||||||
|
|
||||||
|
sentEOF bool |
||||||
|
|
||||||
|
// thread-safe data
|
||||||
|
remoteWin window |
||||||
|
pending *buffer |
||||||
|
extPending *buffer |
||||||
|
|
||||||
|
// windowMu protects myWindow, the flow-control window.
|
||||||
|
windowMu sync.Mutex |
||||||
|
myWindow uint32 |
||||||
|
|
||||||
|
// writeMu serializes calls to mux.conn.writePacket() and
|
||||||
|
// protects sentClose and packetPool. This mutex must be
|
||||||
|
// different from windowMu, as writePacket can block if there
|
||||||
|
// is a key exchange pending.
|
||||||
|
writeMu sync.Mutex |
||||||
|
sentClose bool |
||||||
|
|
||||||
|
// packetPool has a buffer for each extended channel ID to
|
||||||
|
// save allocations during writes.
|
||||||
|
packetPool map[uint32][]byte |
||||||
|
} |
||||||
|
|
||||||
|
// writePacket sends a packet. If the packet is a channel close, it updates
|
||||||
|
// sentClose. This method takes the lock c.writeMu.
|
||||||
|
func (c *channel) writePacket(packet []byte) error { |
||||||
|
c.writeMu.Lock() |
||||||
|
if c.sentClose { |
||||||
|
c.writeMu.Unlock() |
||||||
|
return io.EOF |
||||||
|
} |
||||||
|
c.sentClose = (packet[0] == msgChannelClose) |
||||||
|
err := c.mux.conn.writePacket(packet) |
||||||
|
c.writeMu.Unlock() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) sendMessage(msg interface{}) error { |
||||||
|
if debugMux { |
||||||
|
log.Printf("send %d: %#v", c.mux.chanList.offset, msg) |
||||||
|
} |
||||||
|
|
||||||
|
p := Marshal(msg) |
||||||
|
binary.BigEndian.PutUint32(p[1:], c.remoteId) |
||||||
|
return c.writePacket(p) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteExtended writes data to a specific extended stream. These streams are
|
||||||
|
// used, for example, for stderr.
|
||||||
|
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { |
||||||
|
if c.sentEOF { |
||||||
|
return 0, io.EOF |
||||||
|
} |
||||||
|
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
||||||
|
opCode := byte(msgChannelData) |
||||||
|
headerLength := uint32(9) |
||||||
|
if extendedCode > 0 { |
||||||
|
headerLength += 4 |
||||||
|
opCode = msgChannelExtendedData |
||||||
|
} |
||||||
|
|
||||||
|
c.writeMu.Lock() |
||||||
|
packet := c.packetPool[extendedCode] |
||||||
|
// We don't remove the buffer from packetPool, so
|
||||||
|
// WriteExtended calls from different goroutines will be
|
||||||
|
// flagged as errors by the race detector.
|
||||||
|
c.writeMu.Unlock() |
||||||
|
|
||||||
|
for len(data) > 0 { |
||||||
|
space := min(c.maxRemotePayload, len(data)) |
||||||
|
if space, err = c.remoteWin.reserve(space); err != nil { |
||||||
|
return n, err |
||||||
|
} |
||||||
|
if want := headerLength + space; uint32(cap(packet)) < want { |
||||||
|
packet = make([]byte, want) |
||||||
|
} else { |
||||||
|
packet = packet[:want] |
||||||
|
} |
||||||
|
|
||||||
|
todo := data[:space] |
||||||
|
|
||||||
|
packet[0] = opCode |
||||||
|
binary.BigEndian.PutUint32(packet[1:], c.remoteId) |
||||||
|
if extendedCode > 0 { |
||||||
|
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) |
||||||
|
} |
||||||
|
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) |
||||||
|
copy(packet[headerLength:], todo) |
||||||
|
if err = c.writePacket(packet); err != nil { |
||||||
|
return n, err |
||||||
|
} |
||||||
|
|
||||||
|
n += len(todo) |
||||||
|
data = data[len(todo):] |
||||||
|
} |
||||||
|
|
||||||
|
c.writeMu.Lock() |
||||||
|
c.packetPool[extendedCode] = packet |
||||||
|
c.writeMu.Unlock() |
||||||
|
|
||||||
|
return n, err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) handleData(packet []byte) error { |
||||||
|
headerLen := 9 |
||||||
|
isExtendedData := packet[0] == msgChannelExtendedData |
||||||
|
if isExtendedData { |
||||||
|
headerLen = 13 |
||||||
|
} |
||||||
|
if len(packet) < headerLen { |
||||||
|
// malformed data packet
|
||||||
|
return parseError(packet[0]) |
||||||
|
} |
||||||
|
|
||||||
|
var extended uint32 |
||||||
|
if isExtendedData { |
||||||
|
extended = binary.BigEndian.Uint32(packet[5:]) |
||||||
|
} |
||||||
|
|
||||||
|
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) |
||||||
|
if length == 0 { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if length > c.maxIncomingPayload { |
||||||
|
// TODO(hanwen): should send Disconnect?
|
||||||
|
return errors.New("ssh: incoming packet exceeds maximum payload size") |
||||||
|
} |
||||||
|
|
||||||
|
data := packet[headerLen:] |
||||||
|
if length != uint32(len(data)) { |
||||||
|
return errors.New("ssh: wrong packet length") |
||||||
|
} |
||||||
|
|
||||||
|
c.windowMu.Lock() |
||||||
|
if c.myWindow < length { |
||||||
|
c.windowMu.Unlock() |
||||||
|
// TODO(hanwen): should send Disconnect with reason?
|
||||||
|
return errors.New("ssh: remote side wrote too much") |
||||||
|
} |
||||||
|
c.myWindow -= length |
||||||
|
c.windowMu.Unlock() |
||||||
|
|
||||||
|
if extended == 1 { |
||||||
|
c.extPending.write(data) |
||||||
|
} else if extended > 0 { |
||||||
|
// discard other extended data.
|
||||||
|
} else { |
||||||
|
c.pending.write(data) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) adjustWindow(n uint32) error { |
||||||
|
c.windowMu.Lock() |
||||||
|
// Since myWindow is managed on our side, and can never exceed
|
||||||
|
// the initial window setting, we don't worry about overflow.
|
||||||
|
c.myWindow += uint32(n) |
||||||
|
c.windowMu.Unlock() |
||||||
|
return c.sendMessage(windowAdjustMsg{ |
||||||
|
AdditionalBytes: uint32(n), |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { |
||||||
|
switch extended { |
||||||
|
case 1: |
||||||
|
n, err = c.extPending.Read(data) |
||||||
|
case 0: |
||||||
|
n, err = c.pending.Read(data) |
||||||
|
default: |
||||||
|
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) |
||||||
|
} |
||||||
|
|
||||||
|
if n > 0 { |
||||||
|
err = c.adjustWindow(uint32(n)) |
||||||
|
// sendWindowAdjust can return io.EOF if the remote
|
||||||
|
// peer has closed the connection, however we want to
|
||||||
|
// defer forwarding io.EOF to the caller of Read until
|
||||||
|
// the buffer has been drained.
|
||||||
|
if n > 0 && err == io.EOF { |
||||||
|
err = nil |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return n, err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) close() { |
||||||
|
c.pending.eof() |
||||||
|
c.extPending.eof() |
||||||
|
close(c.msg) |
||||||
|
close(c.incomingRequests) |
||||||
|
c.writeMu.Lock() |
||||||
|
// This is not necesary for a normal channel teardown, but if
|
||||||
|
// there was another error, it is.
|
||||||
|
c.sentClose = true |
||||||
|
c.writeMu.Unlock() |
||||||
|
// Unblock writers.
|
||||||
|
c.remoteWin.close() |
||||||
|
} |
||||||
|
|
||||||
|
// responseMessageReceived is called when a success or failure message is
|
||||||
|
// received on a channel to check that such a message is reasonable for the
|
||||||
|
// given channel.
|
||||||
|
func (c *channel) responseMessageReceived() error { |
||||||
|
if c.direction == channelInbound { |
||||||
|
return errors.New("ssh: channel response message received on inbound channel") |
||||||
|
} |
||||||
|
if c.decided { |
||||||
|
return errors.New("ssh: duplicate response received for channel") |
||||||
|
} |
||||||
|
c.decided = true |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) handlePacket(packet []byte) error { |
||||||
|
switch packet[0] { |
||||||
|
case msgChannelData, msgChannelExtendedData: |
||||||
|
return c.handleData(packet) |
||||||
|
case msgChannelClose: |
||||||
|
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) |
||||||
|
c.mux.chanList.remove(c.localId) |
||||||
|
c.close() |
||||||
|
return nil |
||||||
|
case msgChannelEOF: |
||||||
|
// RFC 4254 is mute on how EOF affects dataExt messages but
|
||||||
|
// it is logical to signal EOF at the same time.
|
||||||
|
c.extPending.eof() |
||||||
|
c.pending.eof() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
decoded, err := decode(packet) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
switch msg := decoded.(type) { |
||||||
|
case *channelOpenFailureMsg: |
||||||
|
if err := c.responseMessageReceived(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
c.mux.chanList.remove(msg.PeersId) |
||||||
|
c.msg <- msg |
||||||
|
case *channelOpenConfirmMsg: |
||||||
|
if err := c.responseMessageReceived(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||||
|
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) |
||||||
|
} |
||||||
|
c.remoteId = msg.MyId |
||||||
|
c.maxRemotePayload = msg.MaxPacketSize |
||||||
|
c.remoteWin.add(msg.MyWindow) |
||||||
|
c.msg <- msg |
||||||
|
case *windowAdjustMsg: |
||||||
|
if !c.remoteWin.add(msg.AdditionalBytes) { |
||||||
|
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) |
||||||
|
} |
||||||
|
case *channelRequestMsg: |
||||||
|
req := Request{ |
||||||
|
Type: msg.Request, |
||||||
|
WantReply: msg.WantReply, |
||||||
|
Payload: msg.RequestSpecificData, |
||||||
|
ch: c, |
||||||
|
} |
||||||
|
|
||||||
|
c.incomingRequests <- &req |
||||||
|
default: |
||||||
|
c.msg <- msg |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { |
||||||
|
ch := &channel{ |
||||||
|
remoteWin: window{Cond: newCond()}, |
||||||
|
myWindow: channelWindowSize, |
||||||
|
pending: newBuffer(), |
||||||
|
extPending: newBuffer(), |
||||||
|
direction: direction, |
||||||
|
incomingRequests: make(chan *Request, 16), |
||||||
|
msg: make(chan interface{}, 16), |
||||||
|
chanType: chanType, |
||||||
|
extraData: extraData, |
||||||
|
mux: m, |
||||||
|
packetPool: make(map[uint32][]byte), |
||||||
|
} |
||||||
|
ch.localId = m.chanList.add(ch) |
||||||
|
return ch |
||||||
|
} |
||||||
|
|
||||||
|
var errUndecided = errors.New("ssh: must Accept or Reject channel") |
||||||
|
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") |
||||||
|
|
||||||
|
type extChannel struct { |
||||||
|
code uint32 |
||||||
|
ch *channel |
||||||
|
} |
||||||
|
|
||||||
|
func (e *extChannel) Write(data []byte) (n int, err error) { |
||||||
|
return e.ch.WriteExtended(data, e.code) |
||||||
|
} |
||||||
|
|
||||||
|
func (e *extChannel) Read(data []byte) (n int, err error) { |
||||||
|
return e.ch.ReadExtended(data, e.code) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *channel) Accept() (Channel, <-chan *Request, error) { |
||||||
|
if c.decided { |
||||||
|
return nil, nil, errDecidedAlready |
||||||
|
} |
||||||
|
c.maxIncomingPayload = channelMaxPacket |
||||||
|
confirm := channelOpenConfirmMsg{ |
||||||
|
PeersId: c.remoteId, |
||||||
|
MyId: c.localId, |
||||||
|
MyWindow: c.myWindow, |
||||||
|
MaxPacketSize: c.maxIncomingPayload, |
||||||
|
} |
||||||
|
c.decided = true |
||||||
|
if err := c.sendMessage(confirm); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return c, c.incomingRequests, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) Reject(reason RejectionReason, message string) error { |
||||||
|
if ch.decided { |
||||||
|
return errDecidedAlready |
||||||
|
} |
||||||
|
reject := channelOpenFailureMsg{ |
||||||
|
PeersId: ch.remoteId, |
||||||
|
Reason: reason, |
||||||
|
Message: message, |
||||||
|
Language: "en", |
||||||
|
} |
||||||
|
ch.decided = true |
||||||
|
return ch.sendMessage(reject) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) Read(data []byte) (int, error) { |
||||||
|
if !ch.decided { |
||||||
|
return 0, errUndecided |
||||||
|
} |
||||||
|
return ch.ReadExtended(data, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) Write(data []byte) (int, error) { |
||||||
|
if !ch.decided { |
||||||
|
return 0, errUndecided |
||||||
|
} |
||||||
|
return ch.WriteExtended(data, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) CloseWrite() error { |
||||||
|
if !ch.decided { |
||||||
|
return errUndecided |
||||||
|
} |
||||||
|
ch.sentEOF = true |
||||||
|
return ch.sendMessage(channelEOFMsg{ |
||||||
|
PeersId: ch.remoteId}) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) Close() error { |
||||||
|
if !ch.decided { |
||||||
|
return errUndecided |
||||||
|
} |
||||||
|
|
||||||
|
return ch.sendMessage(channelCloseMsg{ |
||||||
|
PeersId: ch.remoteId}) |
||||||
|
} |
||||||
|
|
||||||
|
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
||||||
|
// SSH extended stream. Such streams are used, for example, for stderr.
|
||||||
|
func (ch *channel) Extended(code uint32) io.ReadWriter { |
||||||
|
if !ch.decided { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return &extChannel{code, ch} |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) Stderr() io.ReadWriter { |
||||||
|
return ch.Extended(1) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||||
|
if !ch.decided { |
||||||
|
return false, errUndecided |
||||||
|
} |
||||||
|
|
||||||
|
if wantReply { |
||||||
|
ch.sentRequestMu.Lock() |
||||||
|
defer ch.sentRequestMu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
msg := channelRequestMsg{ |
||||||
|
PeersId: ch.remoteId, |
||||||
|
Request: name, |
||||||
|
WantReply: wantReply, |
||||||
|
RequestSpecificData: payload, |
||||||
|
} |
||||||
|
|
||||||
|
if err := ch.sendMessage(msg); err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
|
||||||
|
if wantReply { |
||||||
|
m, ok := (<-ch.msg) |
||||||
|
if !ok { |
||||||
|
return false, io.EOF |
||||||
|
} |
||||||
|
switch m.(type) { |
||||||
|
case *channelRequestFailureMsg: |
||||||
|
return false, nil |
||||||
|
case *channelRequestSuccessMsg: |
||||||
|
return true, nil |
||||||
|
default: |
||||||
|
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return false, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ackRequest either sends an ack or nack to the channel request.
|
||||||
|
func (ch *channel) ackRequest(ok bool) error { |
||||||
|
if !ch.decided { |
||||||
|
return errUndecided |
||||||
|
} |
||||||
|
|
||||||
|
var msg interface{} |
||||||
|
if !ok { |
||||||
|
msg = channelRequestFailureMsg{ |
||||||
|
PeersId: ch.remoteId, |
||||||
|
} |
||||||
|
} else { |
||||||
|
msg = channelRequestSuccessMsg{ |
||||||
|
PeersId: ch.remoteId, |
||||||
|
} |
||||||
|
} |
||||||
|
return ch.sendMessage(msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) ChannelType() string { |
||||||
|
return ch.chanType |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *channel) ExtraData() []byte { |
||||||
|
return ch.extraData |
||||||
|
} |
@ -0,0 +1,549 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/aes" |
||||||
|
"crypto/cipher" |
||||||
|
"crypto/rc4" |
||||||
|
"crypto/subtle" |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"hash" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||||
|
|
||||||
|
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
||||||
|
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
||||||
|
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
||||||
|
// waffles on about reasonable limits.
|
||||||
|
//
|
||||||
|
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
||||||
|
// the same. maxPacket is also used to ensure that uint32
|
||||||
|
// length fields do not overflow, so it should remain well
|
||||||
|
// below 4G.
|
||||||
|
maxPacket = 256 * 1024 |
||||||
|
) |
||||||
|
|
||||||
|
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||||
|
// by the transport before the first key-exchange.
|
||||||
|
type noneCipher struct{} |
||||||
|
|
||||||
|
func (c noneCipher) XORKeyStream(dst, src []byte) { |
||||||
|
copy(dst, src) |
||||||
|
} |
||||||
|
|
||||||
|
func newAESCTR(key, iv []byte) (cipher.Stream, error) { |
||||||
|
c, err := aes.NewCipher(key) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return cipher.NewCTR(c, iv), nil |
||||||
|
} |
||||||
|
|
||||||
|
func newRC4(key, iv []byte) (cipher.Stream, error) { |
||||||
|
return rc4.NewCipher(key) |
||||||
|
} |
||||||
|
|
||||||
|
type streamCipherMode struct { |
||||||
|
keySize int |
||||||
|
ivSize int |
||||||
|
skip int |
||||||
|
createFunc func(key, iv []byte) (cipher.Stream, error) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { |
||||||
|
if len(key) < c.keySize { |
||||||
|
panic("ssh: key length too small for cipher") |
||||||
|
} |
||||||
|
if len(iv) < c.ivSize { |
||||||
|
panic("ssh: iv too small for cipher") |
||||||
|
} |
||||||
|
|
||||||
|
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var streamDump []byte |
||||||
|
if c.skip > 0 { |
||||||
|
streamDump = make([]byte, 512) |
||||||
|
} |
||||||
|
|
||||||
|
for remainingToDump := c.skip; remainingToDump > 0; { |
||||||
|
dumpThisTime := remainingToDump |
||||||
|
if dumpThisTime > len(streamDump) { |
||||||
|
dumpThisTime = len(streamDump) |
||||||
|
} |
||||||
|
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) |
||||||
|
remainingToDump -= dumpThisTime |
||||||
|
} |
||||||
|
|
||||||
|
return stream, nil |
||||||
|
} |
||||||
|
|
||||||
|
// cipherModes documents properties of supported ciphers. Ciphers not included
|
||||||
|
// are not supported and will not be negotiated, even if explicitly requested in
|
||||||
|
// ClientConfig.Crypto.Ciphers.
|
||||||
|
var cipherModes = map[string]*streamCipherMode{ |
||||||
|
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
||||||
|
// are defined in the order specified in the RFC.
|
||||||
|
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, |
||||||
|
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, |
||||||
|
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, |
||||||
|
|
||||||
|
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
||||||
|
// They are defined in the order specified in the RFC.
|
||||||
|
"arcfour128": {16, 0, 1536, newRC4}, |
||||||
|
"arcfour256": {32, 0, 1536, newRC4}, |
||||||
|
|
||||||
|
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
||||||
|
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
||||||
|
// RC4) has problems with weak keys, and should be used with caution."
|
||||||
|
// RFC4345 introduces improved versions of Arcfour.
|
||||||
|
"arcfour": {16, 0, 0, newRC4}, |
||||||
|
|
||||||
|
// AES-GCM is not a stream cipher, so it is constructed with a
|
||||||
|
// special case. If we add any more non-stream ciphers, we
|
||||||
|
// should invest a cleaner way to do this.
|
||||||
|
gcmCipherID: {16, 12, 0, nil}, |
||||||
|
|
||||||
|
// insecure cipher, see http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf
|
||||||
|
// uncomment below to enable it.
|
||||||
|
// aes128cbcID: {16, aes.BlockSize, 0, nil},
|
||||||
|
} |
||||||
|
|
||||||
|
// prefixLen is the length of the packet prefix that contains the packet length
|
||||||
|
// and number of padding bytes.
|
||||||
|
const prefixLen = 5 |
||||||
|
|
||||||
|
// streamPacketCipher is a packetCipher using a stream cipher.
|
||||||
|
type streamPacketCipher struct { |
||||||
|
mac hash.Hash |
||||||
|
cipher cipher.Stream |
||||||
|
|
||||||
|
// The following members are to avoid per-packet allocations.
|
||||||
|
prefix [prefixLen]byte |
||||||
|
seqNumBytes [4]byte |
||||||
|
padding [2 * packetSizeMultiple]byte |
||||||
|
packetData []byte |
||||||
|
macResult []byte |
||||||
|
} |
||||||
|
|
||||||
|
// readPacket reads and decrypt a single packet from the reader argument.
|
||||||
|
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||||
|
if _, err := io.ReadFull(r, s.prefix[:]); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||||
|
length := binary.BigEndian.Uint32(s.prefix[0:4]) |
||||||
|
paddingLength := uint32(s.prefix[4]) |
||||||
|
|
||||||
|
var macSize uint32 |
||||||
|
if s.mac != nil { |
||||||
|
s.mac.Reset() |
||||||
|
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||||
|
s.mac.Write(s.seqNumBytes[:]) |
||||||
|
s.mac.Write(s.prefix[:]) |
||||||
|
macSize = uint32(s.mac.Size()) |
||||||
|
} |
||||||
|
|
||||||
|
if length <= paddingLength+1 { |
||||||
|
return nil, errors.New("ssh: invalid packet length, packet too small") |
||||||
|
} |
||||||
|
|
||||||
|
if length > maxPacket { |
||||||
|
return nil, errors.New("ssh: invalid packet length, packet too large") |
||||||
|
} |
||||||
|
|
||||||
|
// the maxPacket check above ensures that length-1+macSize
|
||||||
|
// does not overflow.
|
||||||
|
if uint32(cap(s.packetData)) < length-1+macSize { |
||||||
|
s.packetData = make([]byte, length-1+macSize) |
||||||
|
} else { |
||||||
|
s.packetData = s.packetData[:length-1+macSize] |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := io.ReadFull(r, s.packetData); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
mac := s.packetData[length-1:] |
||||||
|
data := s.packetData[:length-1] |
||||||
|
s.cipher.XORKeyStream(data, data) |
||||||
|
|
||||||
|
if s.mac != nil { |
||||||
|
s.mac.Write(data) |
||||||
|
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||||
|
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { |
||||||
|
return nil, errors.New("ssh: MAC failure") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return s.packetData[:length-paddingLength-1], nil |
||||||
|
} |
||||||
|
|
||||||
|
// writePacket encrypts and sends a packet of data to the writer argument
|
||||||
|
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||||
|
if len(packet) > maxPacket { |
||||||
|
return errors.New("ssh: packet too large") |
||||||
|
} |
||||||
|
|
||||||
|
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple |
||||||
|
if paddingLength < 4 { |
||||||
|
paddingLength += packetSizeMultiple |
||||||
|
} |
||||||
|
|
||||||
|
length := len(packet) + 1 + paddingLength |
||||||
|
binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) |
||||||
|
s.prefix[4] = byte(paddingLength) |
||||||
|
padding := s.padding[:paddingLength] |
||||||
|
if _, err := io.ReadFull(rand, padding); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if s.mac != nil { |
||||||
|
s.mac.Reset() |
||||||
|
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||||
|
s.mac.Write(s.seqNumBytes[:]) |
||||||
|
s.mac.Write(s.prefix[:]) |
||||||
|
s.mac.Write(packet) |
||||||
|
s.mac.Write(padding) |
||||||
|
} |
||||||
|
|
||||||
|
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||||
|
s.cipher.XORKeyStream(packet, packet) |
||||||
|
s.cipher.XORKeyStream(padding, padding) |
||||||
|
|
||||||
|
if _, err := w.Write(s.prefix[:]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, err := w.Write(packet); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if _, err := w.Write(padding); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if s.mac != nil { |
||||||
|
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||||
|
if _, err := w.Write(s.macResult); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
type gcmCipher struct { |
||||||
|
aead cipher.AEAD |
||||||
|
prefix [4]byte |
||||||
|
iv []byte |
||||||
|
buf []byte |
||||||
|
} |
||||||
|
|
||||||
|
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { |
||||||
|
c, err := aes.NewCipher(key) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
aead, err := cipher.NewGCM(c) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &gcmCipher{ |
||||||
|
aead: aead, |
||||||
|
iv: iv, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
const gcmTagSize = 16 |
||||||
|
|
||||||
|
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||||
|
// Pad out to multiple of 16 bytes. This is different from the
|
||||||
|
// stream cipher because that encrypts the length too.
|
||||||
|
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) |
||||||
|
if padding < 4 { |
||||||
|
padding += packetSizeMultiple |
||||||
|
} |
||||||
|
|
||||||
|
length := uint32(len(packet) + int(padding) + 1) |
||||||
|
binary.BigEndian.PutUint32(c.prefix[:], length) |
||||||
|
if _, err := w.Write(c.prefix[:]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if cap(c.buf) < int(length) { |
||||||
|
c.buf = make([]byte, length) |
||||||
|
} else { |
||||||
|
c.buf = c.buf[:length] |
||||||
|
} |
||||||
|
|
||||||
|
c.buf[0] = padding |
||||||
|
copy(c.buf[1:], packet) |
||||||
|
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||||
|
if _, err := w.Write(c.buf); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
c.incIV() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *gcmCipher) incIV() { |
||||||
|
for i := 4 + 7; i >= 4; i-- { |
||||||
|
c.iv[i]++ |
||||||
|
if c.iv[i] != 0 { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||||
|
if _, err := io.ReadFull(r, c.prefix[:]); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
length := binary.BigEndian.Uint32(c.prefix[:]) |
||||||
|
if length > maxPacket { |
||||||
|
return nil, errors.New("ssh: max packet length exceeded.") |
||||||
|
} |
||||||
|
|
||||||
|
if cap(c.buf) < int(length+gcmTagSize) { |
||||||
|
c.buf = make([]byte, length+gcmTagSize) |
||||||
|
} else { |
||||||
|
c.buf = c.buf[:length+gcmTagSize] |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := io.ReadFull(r, c.buf); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
c.incIV() |
||||||
|
|
||||||
|
padding := plain[0] |
||||||
|
if padding < 4 || padding >= 20 { |
||||||
|
return nil, fmt.Errorf("ssh: illegal padding %d", padding) |
||||||
|
} |
||||||
|
|
||||||
|
if int(padding+1) >= len(plain) { |
||||||
|
return nil, fmt.Errorf("ssh: padding %d too large", padding) |
||||||
|
} |
||||||
|
plain = plain[1 : length-uint32(padding)] |
||||||
|
return plain, nil |
||||||
|
} |
||||||
|
|
||||||
|
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
||||||
|
type cbcCipher struct { |
||||||
|
mac hash.Hash |
||||||
|
macSize uint32 |
||||||
|
decrypter cipher.BlockMode |
||||||
|
encrypter cipher.BlockMode |
||||||
|
|
||||||
|
// The following members are to avoid per-packet allocations.
|
||||||
|
seqNumBytes [4]byte |
||||||
|
packetData []byte |
||||||
|
macResult []byte |
||||||
|
|
||||||
|
// Amount of data we should still read to hide which
|
||||||
|
// verification error triggered.
|
||||||
|
oracleCamouflage uint32 |
||||||
|
} |
||||||
|
|
||||||
|
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||||
|
c, err := aes.NewCipher(key) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
cbc := &cbcCipher{ |
||||||
|
mac: macModes[algs.MAC].new(macKey), |
||||||
|
decrypter: cipher.NewCBCDecrypter(c, iv), |
||||||
|
encrypter: cipher.NewCBCEncrypter(c, iv), |
||||||
|
packetData: make([]byte, 1024), |
||||||
|
} |
||||||
|
if cbc.mac != nil { |
||||||
|
cbc.macSize = uint32(cbc.mac.Size()) |
||||||
|
} |
||||||
|
|
||||||
|
return cbc, nil |
||||||
|
} |
||||||
|
|
||||||
|
func maxUInt32(a, b int) uint32 { |
||||||
|
if a > b { |
||||||
|
return uint32(a) |
||||||
|
} |
||||||
|
return uint32(b) |
||||||
|
} |
||||||
|
|
||||||
|
const ( |
||||||
|
cbcMinPacketSizeMultiple = 8 |
||||||
|
cbcMinPacketSize = 16 |
||||||
|
cbcMinPaddingSize = 4 |
||||||
|
) |
||||||
|
|
||||||
|
// cbcError represents a verification error that may leak information.
|
||||||
|
type cbcError string |
||||||
|
|
||||||
|
func (e cbcError) Error() string { return string(e) } |
||||||
|
|
||||||
|
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||||
|
p, err := c.readPacketLeaky(seqNum, r) |
||||||
|
if err != nil { |
||||||
|
if _, ok := err.(cbcError); ok { |
||||||
|
// Verification error: read a fixed amount of
|
||||||
|
// data, to make distinguishing between
|
||||||
|
// failing MAC and failing length check more
|
||||||
|
// difficult.
|
||||||
|
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) |
||||||
|
} |
||||||
|
} |
||||||
|
return p, err |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { |
||||||
|
blockSize := c.decrypter.BlockSize() |
||||||
|
|
||||||
|
// Read the header, which will include some of the subsequent data in the
|
||||||
|
// case of block ciphers - this is copied back to the payload later.
|
||||||
|
// How many bytes of payload/padding will be read with this first read.
|
||||||
|
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) |
||||||
|
firstBlock := c.packetData[:firstBlockLength] |
||||||
|
if _, err := io.ReadFull(r, firstBlock); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength |
||||||
|
|
||||||
|
c.decrypter.CryptBlocks(firstBlock, firstBlock) |
||||||
|
length := binary.BigEndian.Uint32(firstBlock[:4]) |
||||||
|
if length > maxPacket { |
||||||
|
return nil, cbcError("ssh: packet too large") |
||||||
|
} |
||||||
|
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { |
||||||
|
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
||||||
|
// is larger) bytes.
|
||||||
|
return nil, cbcError("ssh: packet too small") |
||||||
|
} |
||||||
|
// The length of the packet (including the length field but not the MAC) must
|
||||||
|
// be a multiple of the block size or 8, whichever is larger.
|
||||||
|
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { |
||||||
|
return nil, cbcError("ssh: invalid packet length multiple") |
||||||
|
} |
||||||
|
|
||||||
|
paddingLength := uint32(firstBlock[4]) |
||||||
|
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { |
||||||
|
return nil, cbcError("ssh: invalid packet length") |
||||||
|
} |
||||||
|
|
||||||
|
// Positions within the c.packetData buffer:
|
||||||
|
macStart := 4 + length |
||||||
|
paddingStart := macStart - paddingLength |
||||||
|
|
||||||
|
// Entire packet size, starting before length, ending at end of mac.
|
||||||
|
entirePacketSize := macStart + c.macSize |
||||||
|
|
||||||
|
// Ensure c.packetData is large enough for the entire packet data.
|
||||||
|
if uint32(cap(c.packetData)) < entirePacketSize { |
||||||
|
// Still need to upsize and copy, but this should be rare at runtime, only
|
||||||
|
// on upsizing the packetData buffer.
|
||||||
|
c.packetData = make([]byte, entirePacketSize) |
||||||
|
copy(c.packetData, firstBlock) |
||||||
|
} else { |
||||||
|
c.packetData = c.packetData[:entirePacketSize] |
||||||
|
} |
||||||
|
|
||||||
|
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { |
||||||
|
return nil, err |
||||||
|
} else { |
||||||
|
c.oracleCamouflage -= uint32(n) |
||||||
|
} |
||||||
|
|
||||||
|
remainingCrypted := c.packetData[firstBlockLength:macStart] |
||||||
|
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) |
||||||
|
|
||||||
|
mac := c.packetData[macStart:] |
||||||
|
if c.mac != nil { |
||||||
|
c.mac.Reset() |
||||||
|
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||||
|
c.mac.Write(c.seqNumBytes[:]) |
||||||
|
c.mac.Write(c.packetData[:macStart]) |
||||||
|
c.macResult = c.mac.Sum(c.macResult[:0]) |
||||||
|
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { |
||||||
|
return nil, cbcError("ssh: MAC failure") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return c.packetData[prefixLen:paddingStart], nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||||
|
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) |
||||||
|
|
||||||
|
// Length of encrypted portion of the packet (header, payload, padding).
|
||||||
|
// Enforce minimum padding and packet size.
|
||||||
|
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) |
||||||
|
// Enforce block size.
|
||||||
|
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize |
||||||
|
|
||||||
|
length := encLength - 4 |
||||||
|
paddingLength := int(length) - (1 + len(packet)) |
||||||
|
|
||||||
|
// Overall buffer contains: header, payload, padding, mac.
|
||||||
|
// Space for the MAC is reserved in the capacity but not the slice length.
|
||||||
|
bufferSize := encLength + c.macSize |
||||||
|
if uint32(cap(c.packetData)) < bufferSize { |
||||||
|
c.packetData = make([]byte, encLength, bufferSize) |
||||||
|
} else { |
||||||
|
c.packetData = c.packetData[:encLength] |
||||||
|
} |
||||||
|
|
||||||
|
p := c.packetData |
||||||
|
|
||||||
|
// Packet header.
|
||||||
|
binary.BigEndian.PutUint32(p, length) |
||||||
|
p = p[4:] |
||||||
|
p[0] = byte(paddingLength) |
||||||
|
|
||||||
|
// Payload.
|
||||||
|
p = p[1:] |
||||||
|
copy(p, packet) |
||||||
|
|
||||||
|
// Padding.
|
||||||
|
p = p[len(packet):] |
||||||
|
if _, err := io.ReadFull(rand, p); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if c.mac != nil { |
||||||
|
c.mac.Reset() |
||||||
|
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||||
|
c.mac.Write(c.seqNumBytes[:]) |
||||||
|
c.mac.Write(c.packetData) |
||||||
|
// The MAC is now appended into the capacity reserved for it earlier.
|
||||||
|
c.packetData = c.mac.Sum(c.packetData) |
||||||
|
} |
||||||
|
|
||||||
|
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) |
||||||
|
|
||||||
|
if _, err := w.Write(c.packetData); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,127 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto" |
||||||
|
"crypto/aes" |
||||||
|
"crypto/rand" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestDefaultCiphersExist(t *testing.T) { |
||||||
|
for _, cipherAlgo := range supportedCiphers { |
||||||
|
if _, ok := cipherModes[cipherAlgo]; !ok { |
||||||
|
t.Errorf("default cipher %q is unknown", cipherAlgo) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestPacketCiphers(t *testing.T) { |
||||||
|
// Still test aes128cbc cipher althought it's commented out.
|
||||||
|
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||||
|
defer delete(cipherModes, aes128cbcID) |
||||||
|
|
||||||
|
for cipher := range cipherModes { |
||||||
|
kr := &kexResult{Hash: crypto.SHA1} |
||||||
|
algs := directionAlgorithms{ |
||||||
|
Cipher: cipher, |
||||||
|
MAC: "hmac-sha1", |
||||||
|
Compression: "none", |
||||||
|
} |
||||||
|
client, err := newPacketCipher(clientKeys, algs, kr) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||||
|
continue |
||||||
|
} |
||||||
|
server, err := newPacketCipher(clientKeys, algs, kr) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
want := "bla bla" |
||||||
|
input := []byte(want) |
||||||
|
buf := &bytes.Buffer{} |
||||||
|
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||||
|
t.Errorf("writePacket(%q): %v", cipher, err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
packet, err := server.readPacket(0, buf) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("readPacket(%q): %v", cipher, err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if string(packet) != want { |
||||||
|
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCBCOracleCounterMeasure(t *testing.T) { |
||||||
|
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||||
|
defer delete(cipherModes, aes128cbcID) |
||||||
|
|
||||||
|
kr := &kexResult{Hash: crypto.SHA1} |
||||||
|
algs := directionAlgorithms{ |
||||||
|
Cipher: aes128cbcID, |
||||||
|
MAC: "hmac-sha1", |
||||||
|
Compression: "none", |
||||||
|
} |
||||||
|
client, err := newPacketCipher(clientKeys, algs, kr) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("newPacketCipher(client): %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
want := "bla bla" |
||||||
|
input := []byte(want) |
||||||
|
buf := &bytes.Buffer{} |
||||||
|
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
packetSize := buf.Len() |
||||||
|
buf.Write(make([]byte, 2*maxPacket)) |
||||||
|
|
||||||
|
// We corrupt each byte, but this usually will only test the
|
||||||
|
// 'packet too large' or 'MAC failure' cases.
|
||||||
|
lastRead := -1 |
||||||
|
for i := 0; i < packetSize; i++ { |
||||||
|
server, err := newPacketCipher(clientKeys, algs, kr) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("newPacketCipher(client): %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
fresh := &bytes.Buffer{} |
||||||
|
fresh.Write(buf.Bytes()) |
||||||
|
fresh.Bytes()[i] ^= 0x01 |
||||||
|
|
||||||
|
before := fresh.Len() |
||||||
|
_, err = server.readPacket(0, fresh) |
||||||
|
if err == nil { |
||||||
|
t.Errorf("corrupt byte %d: readPacket succeeded ", i) |
||||||
|
continue |
||||||
|
} |
||||||
|
if _, ok := err.(cbcError); !ok { |
||||||
|
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
after := fresh.Len() |
||||||
|
bytesRead := before - after |
||||||
|
if bytesRead < maxPacket { |
||||||
|
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if i > 0 && bytesRead != lastRead { |
||||||
|
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) |
||||||
|
} |
||||||
|
lastRead = bytesRead |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,213 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
// Client implements a traditional SSH client that supports shells,
|
||||||
|
// subprocesses, port forwarding and tunneled dialing.
|
||||||
|
type Client struct { |
||||||
|
Conn |
||||||
|
|
||||||
|
forwards forwardList // forwarded tcpip connections from the remote side
|
||||||
|
mu sync.Mutex |
||||||
|
channelHandlers map[string]chan NewChannel |
||||||
|
} |
||||||
|
|
||||||
|
// HandleChannelOpen returns a channel on which NewChannel requests
|
||||||
|
// for the given type are sent. If the type already is being handled,
|
||||||
|
// nil is returned. The channel is closed when the connection is closed.
|
||||||
|
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { |
||||||
|
c.mu.Lock() |
||||||
|
defer c.mu.Unlock() |
||||||
|
if c.channelHandlers == nil { |
||||||
|
// The SSH channel has been closed.
|
||||||
|
c := make(chan NewChannel) |
||||||
|
close(c) |
||||||
|
return c |
||||||
|
} |
||||||
|
|
||||||
|
ch := c.channelHandlers[channelType] |
||||||
|
if ch != nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
ch = make(chan NewChannel, 16) |
||||||
|
c.channelHandlers[channelType] = ch |
||||||
|
return ch |
||||||
|
} |
||||||
|
|
||||||
|
// NewClient creates a Client on top of the given connection.
|
||||||
|
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { |
||||||
|
conn := &Client{ |
||||||
|
Conn: c, |
||||||
|
channelHandlers: make(map[string]chan NewChannel, 1), |
||||||
|
} |
||||||
|
|
||||||
|
go conn.handleGlobalRequests(reqs) |
||||||
|
go conn.handleChannelOpens(chans) |
||||||
|
go func() { |
||||||
|
conn.Wait() |
||||||
|
conn.forwards.closeAll() |
||||||
|
}() |
||||||
|
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) |
||||||
|
return conn |
||||||
|
} |
||||||
|
|
||||||
|
// NewClientConn establishes an authenticated SSH connection using c
|
||||||
|
// as the underlying transport. The Request and NewChannel channels
|
||||||
|
// must be serviced or the connection will hang.
|
||||||
|
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { |
||||||
|
fullConf := *config |
||||||
|
fullConf.SetDefaults() |
||||||
|
conn := &connection{ |
||||||
|
sshConn: sshConn{conn: c}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := conn.clientHandshake(addr, &fullConf); err != nil { |
||||||
|
c.Close() |
||||||
|
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) |
||||||
|
} |
||||||
|
conn.mux = newMux(conn.transport) |
||||||
|
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil |
||||||
|
} |
||||||
|
|
||||||
|
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
||||||
|
// 7.
|
||||||
|
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { |
||||||
|
if config.ClientVersion != "" { |
||||||
|
c.clientVersion = []byte(config.ClientVersion) |
||||||
|
} else { |
||||||
|
c.clientVersion = []byte(packageVersion) |
||||||
|
} |
||||||
|
var err error |
||||||
|
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
c.transport = newClientTransport( |
||||||
|
newTransport(c.sshConn.conn, config.Rand, true /* is client */), |
||||||
|
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) |
||||||
|
if err := c.transport.requestKeyChange(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if packet, err := c.transport.readPacket(); err != nil { |
||||||
|
return err |
||||||
|
} else if packet[0] != msgNewKeys { |
||||||
|
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||||
|
} |
||||||
|
|
||||||
|
// We just did the key change, so the session ID is established.
|
||||||
|
c.sessionID = c.transport.getSessionID() |
||||||
|
|
||||||
|
return c.clientAuthenticate(config) |
||||||
|
} |
||||||
|
|
||||||
|
// verifyHostKeySignature verifies the host key obtained in the key
|
||||||
|
// exchange.
|
||||||
|
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { |
||||||
|
sig, rest, ok := parseSignatureBody(result.Signature) |
||||||
|
if len(rest) > 0 || !ok { |
||||||
|
return errors.New("ssh: signature parse error") |
||||||
|
} |
||||||
|
|
||||||
|
return hostKey.Verify(result.H, sig) |
||||||
|
} |
||||||
|
|
||||||
|
// NewSession opens a new Session for this client. (A session is a remote
|
||||||
|
// execution of a program.)
|
||||||
|
func (c *Client) NewSession() (*Session, error) { |
||||||
|
ch, in, err := c.OpenChannel("session", nil) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return newSession(ch, in) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) handleGlobalRequests(incoming <-chan *Request) { |
||||||
|
for r := range incoming { |
||||||
|
// This handles keepalive messages and matches
|
||||||
|
// the behaviour of OpenSSH.
|
||||||
|
r.Reply(false, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// handleChannelOpens channel open messages from the remote side.
|
||||||
|
func (c *Client) handleChannelOpens(in <-chan NewChannel) { |
||||||
|
for ch := range in { |
||||||
|
c.mu.Lock() |
||||||
|
handler := c.channelHandlers[ch.ChannelType()] |
||||||
|
c.mu.Unlock() |
||||||
|
|
||||||
|
if handler != nil { |
||||||
|
handler <- ch |
||||||
|
} else { |
||||||
|
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
c.mu.Lock() |
||||||
|
for _, ch := range c.channelHandlers { |
||||||
|
close(ch) |
||||||
|
} |
||||||
|
c.channelHandlers = nil |
||||||
|
c.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// Dial starts a client connection to the given SSH server. It is a
|
||||||
|
// convenience function that connects to the given network address,
|
||||||
|
// initiates the SSH handshake, and then sets up a Client. For access
|
||||||
|
// to incoming channels and requests, use net.Dial with NewClientConn
|
||||||
|
// instead.
|
||||||
|
func Dial(network, addr string, config *ClientConfig) (*Client, error) { |
||||||
|
conn, err := net.Dial(network, addr) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
c, chans, reqs, err := NewClientConn(conn, addr, config) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return NewClient(c, chans, reqs), nil |
||||||
|
} |
||||||
|
|
||||||
|
// A ClientConfig structure is used to configure a Client. It must not be
|
||||||
|
// modified after having been passed to an SSH function.
|
||||||
|
type ClientConfig struct { |
||||||
|
// Config contains configuration that is shared between clients and
|
||||||
|
// servers.
|
||||||
|
Config |
||||||
|
|
||||||
|
// User contains the username to authenticate as.
|
||||||
|
User string |
||||||
|
|
||||||
|
// Auth contains possible authentication methods to use with the
|
||||||
|
// server. Only the first instance of a particular RFC 4252 method will
|
||||||
|
// be used during authentication.
|
||||||
|
Auth []AuthMethod |
||||||
|
|
||||||
|
// HostKeyCallback, if not nil, is called during the cryptographic
|
||||||
|
// handshake to validate the server's host key. A nil HostKeyCallback
|
||||||
|
// implies that all host keys are accepted.
|
||||||
|
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||||
|
|
||||||
|
// ClientVersion contains the version identification string that will
|
||||||
|
// be used for the connection. If empty, a reasonable default is used.
|
||||||
|
ClientVersion string |
||||||
|
|
||||||
|
// HostKeyAlgorithms lists the key types that the client will
|
||||||
|
// accept from the server as host key, in order of
|
||||||
|
// preference. If empty, a reasonable default is used. Any
|
||||||
|
// string returned from PublicKey.Type method may be used, or
|
||||||
|
// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
|
||||||
|
HostKeyAlgorithms []string |
||||||
|
} |
@ -0,0 +1,441 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
) |
||||||
|
|
||||||
|
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
||||||
|
func (c *connection) clientAuthenticate(config *ClientConfig) error { |
||||||
|
// initiate user auth session
|
||||||
|
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
packet, err := c.transport.readPacket() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
var serviceAccept serviceAcceptMsg |
||||||
|
if err := Unmarshal(packet, &serviceAccept); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// during the authentication phase the client first attempts the "none" method
|
||||||
|
// then any untried methods suggested by the server.
|
||||||
|
tried := make(map[string]bool) |
||||||
|
var lastMethods []string |
||||||
|
for auth := AuthMethod(new(noneAuth)); auth != nil; { |
||||||
|
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if ok { |
||||||
|
// success
|
||||||
|
return nil |
||||||
|
} |
||||||
|
tried[auth.method()] = true |
||||||
|
if methods == nil { |
||||||
|
methods = lastMethods |
||||||
|
} |
||||||
|
lastMethods = methods |
||||||
|
|
||||||
|
auth = nil |
||||||
|
|
||||||
|
findNext: |
||||||
|
for _, a := range config.Auth { |
||||||
|
candidateMethod := a.method() |
||||||
|
if tried[candidateMethod] { |
||||||
|
continue |
||||||
|
} |
||||||
|
for _, meth := range methods { |
||||||
|
if meth == candidateMethod { |
||||||
|
auth = a |
||||||
|
break findNext |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) |
||||||
|
} |
||||||
|
|
||||||
|
func keys(m map[string]bool) []string { |
||||||
|
s := make([]string, 0, len(m)) |
||||||
|
|
||||||
|
for key := range m { |
||||||
|
s = append(s, key) |
||||||
|
} |
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
||||||
|
type AuthMethod interface { |
||||||
|
// auth authenticates user over transport t.
|
||||||
|
// Returns true if authentication is successful.
|
||||||
|
// If authentication is not successful, a []string of alternative
|
||||||
|
// method names is returned. If the slice is nil, it will be ignored
|
||||||
|
// and the previous set of possible methods will be reused.
|
||||||
|
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) |
||||||
|
|
||||||
|
// method returns the RFC 4252 method name.
|
||||||
|
method() string |
||||||
|
} |
||||||
|
|
||||||
|
// "none" authentication, RFC 4252 section 5.2.
|
||||||
|
type noneAuth int |
||||||
|
|
||||||
|
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||||
|
if err := c.writePacket(Marshal(&userAuthRequestMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: "none", |
||||||
|
})); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return handleAuthResponse(c) |
||||||
|
} |
||||||
|
|
||||||
|
func (n *noneAuth) method() string { |
||||||
|
return "none" |
||||||
|
} |
||||||
|
|
||||||
|
// passwordCallback is an AuthMethod that fetches the password through
|
||||||
|
// a function call, e.g. by prompting the user.
|
||||||
|
type passwordCallback func() (password string, err error) |
||||||
|
|
||||||
|
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||||
|
type passwordAuthMsg struct { |
||||||
|
User string `sshtype:"50"` |
||||||
|
Service string |
||||||
|
Method string |
||||||
|
Reply bool |
||||||
|
Password string |
||||||
|
} |
||||||
|
|
||||||
|
pw, err := cb() |
||||||
|
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
||||||
|
// The program may only find out that the user doesn't have a password
|
||||||
|
// when prompting.
|
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if err := c.writePacket(Marshal(&passwordAuthMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: cb.method(), |
||||||
|
Reply: false, |
||||||
|
Password: pw, |
||||||
|
})); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return handleAuthResponse(c) |
||||||
|
} |
||||||
|
|
||||||
|
func (cb passwordCallback) method() string { |
||||||
|
return "password" |
||||||
|
} |
||||||
|
|
||||||
|
// Password returns an AuthMethod using the given password.
|
||||||
|
func Password(secret string) AuthMethod { |
||||||
|
return passwordCallback(func() (string, error) { return secret, nil }) |
||||||
|
} |
||||||
|
|
||||||
|
// PasswordCallback returns an AuthMethod that uses a callback for
|
||||||
|
// fetching a password.
|
||||||
|
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { |
||||||
|
return passwordCallback(prompt) |
||||||
|
} |
||||||
|
|
||||||
|
type publickeyAuthMsg struct { |
||||||
|
User string `sshtype:"50"` |
||||||
|
Service string |
||||||
|
Method string |
||||||
|
// HasSig indicates to the receiver packet that the auth request is signed and
|
||||||
|
// should be used for authentication of the request.
|
||||||
|
HasSig bool |
||||||
|
Algoname string |
||||||
|
PubKey []byte |
||||||
|
// Sig is tagged with "rest" so Marshal will exclude it during
|
||||||
|
// validateKey
|
||||||
|
Sig []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// publicKeyCallback is an AuthMethod that uses a set of key
|
||||||
|
// pairs for authentication.
|
||||||
|
type publicKeyCallback func() ([]Signer, error) |
||||||
|
|
||||||
|
func (cb publicKeyCallback) method() string { |
||||||
|
return "publickey" |
||||||
|
} |
||||||
|
|
||||||
|
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||||
|
// Authentication is performed in two stages. The first stage sends an
|
||||||
|
// enquiry to test if each key is acceptable to the remote. The second
|
||||||
|
// stage attempts to authenticate with the valid keys obtained in the
|
||||||
|
// first stage.
|
||||||
|
|
||||||
|
signers, err := cb() |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
var validKeys []Signer |
||||||
|
for _, signer := range signers { |
||||||
|
if ok, err := validateKey(signer.PublicKey(), user, c); ok { |
||||||
|
validKeys = append(validKeys, signer) |
||||||
|
} else { |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// methods that may continue if this auth is not successful.
|
||||||
|
var methods []string |
||||||
|
for _, signer := range validKeys { |
||||||
|
pub := signer.PublicKey() |
||||||
|
|
||||||
|
pubKey := pub.Marshal() |
||||||
|
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: cb.method(), |
||||||
|
}, []byte(pub.Type()), pubKey)) |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// manually wrap the serialized signature in a string
|
||||||
|
s := Marshal(sign) |
||||||
|
sig := make([]byte, stringLength(len(s))) |
||||||
|
marshalString(sig, s) |
||||||
|
msg := publickeyAuthMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: cb.method(), |
||||||
|
HasSig: true, |
||||||
|
Algoname: pub.Type(), |
||||||
|
PubKey: pubKey, |
||||||
|
Sig: sig, |
||||||
|
} |
||||||
|
p := Marshal(&msg) |
||||||
|
if err := c.writePacket(p); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
var success bool |
||||||
|
success, methods, err = handleAuthResponse(c) |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
if success { |
||||||
|
return success, methods, err |
||||||
|
} |
||||||
|
} |
||||||
|
return false, methods, nil |
||||||
|
} |
||||||
|
|
||||||
|
// validateKey validates the key provided is acceptable to the server.
|
||||||
|
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { |
||||||
|
pubKey := key.Marshal() |
||||||
|
msg := publickeyAuthMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: "publickey", |
||||||
|
HasSig: false, |
||||||
|
Algoname: key.Type(), |
||||||
|
PubKey: pubKey, |
||||||
|
} |
||||||
|
if err := c.writePacket(Marshal(&msg)); err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
|
||||||
|
return confirmKeyAck(key, c) |
||||||
|
} |
||||||
|
|
||||||
|
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { |
||||||
|
pubKey := key.Marshal() |
||||||
|
algoname := key.Type() |
||||||
|
|
||||||
|
for { |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
switch packet[0] { |
||||||
|
case msgUserAuthBanner: |
||||||
|
// TODO(gpaul): add callback to present the banner to the user
|
||||||
|
case msgUserAuthPubKeyOk: |
||||||
|
var msg userAuthPubKeyOkMsg |
||||||
|
if err := Unmarshal(packet, &msg); err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { |
||||||
|
return false, nil |
||||||
|
} |
||||||
|
return true, nil |
||||||
|
case msgUserAuthFailure: |
||||||
|
return false, nil |
||||||
|
default: |
||||||
|
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// PublicKeys returns an AuthMethod that uses the given key
|
||||||
|
// pairs.
|
||||||
|
func PublicKeys(signers ...Signer) AuthMethod { |
||||||
|
return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) |
||||||
|
} |
||||||
|
|
||||||
|
// PublicKeysCallback returns an AuthMethod that runs the given
|
||||||
|
// function to obtain a list of key pairs.
|
||||||
|
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { |
||||||
|
return publicKeyCallback(getSigners) |
||||||
|
} |
||||||
|
|
||||||
|
// handleAuthResponse returns whether the preceding authentication request succeeded
|
||||||
|
// along with a list of remaining authentication methods to try next and
|
||||||
|
// an error if an unexpected response was received.
|
||||||
|
func handleAuthResponse(c packetConn) (bool, []string, error) { |
||||||
|
for { |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
switch packet[0] { |
||||||
|
case msgUserAuthBanner: |
||||||
|
// TODO: add callback to present the banner to the user
|
||||||
|
case msgUserAuthFailure: |
||||||
|
var msg userAuthFailureMsg |
||||||
|
if err := Unmarshal(packet, &msg); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
return false, msg.Methods, nil |
||||||
|
case msgUserAuthSuccess: |
||||||
|
return true, nil, nil |
||||||
|
case msgDisconnect: |
||||||
|
return false, nil, io.EOF |
||||||
|
default: |
||||||
|
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// KeyboardInteractiveChallenge should print questions, optionally
|
||||||
|
// disabling echoing (e.g. for passwords), and return all the answers.
|
||||||
|
// Challenge may be called multiple times in a single session. After
|
||||||
|
// successful authentication, the server may send a challenge with no
|
||||||
|
// questions, for which the user and instruction messages should be
|
||||||
|
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
||||||
|
// both CLI and GUI environments.
|
||||||
|
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) |
||||||
|
|
||||||
|
// KeyboardInteractive returns a AuthMethod using a prompt/response
|
||||||
|
// sequence controlled by the server.
|
||||||
|
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { |
||||||
|
return challenge |
||||||
|
} |
||||||
|
|
||||||
|
func (cb KeyboardInteractiveChallenge) method() string { |
||||||
|
return "keyboard-interactive" |
||||||
|
} |
||||||
|
|
||||||
|
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||||
|
type initiateMsg struct { |
||||||
|
User string `sshtype:"50"` |
||||||
|
Service string |
||||||
|
Method string |
||||||
|
Language string |
||||||
|
Submethods string |
||||||
|
} |
||||||
|
|
||||||
|
if err := c.writePacket(Marshal(&initiateMsg{ |
||||||
|
User: user, |
||||||
|
Service: serviceSSH, |
||||||
|
Method: "keyboard-interactive", |
||||||
|
})); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
for { |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// like handleAuthResponse, but with less options.
|
||||||
|
switch packet[0] { |
||||||
|
case msgUserAuthBanner: |
||||||
|
// TODO: Print banners during userauth.
|
||||||
|
continue |
||||||
|
case msgUserAuthInfoRequest: |
||||||
|
// OK
|
||||||
|
case msgUserAuthFailure: |
||||||
|
var msg userAuthFailureMsg |
||||||
|
if err := Unmarshal(packet, &msg); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
return false, msg.Methods, nil |
||||||
|
case msgUserAuthSuccess: |
||||||
|
return true, nil, nil |
||||||
|
default: |
||||||
|
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) |
||||||
|
} |
||||||
|
|
||||||
|
var msg userAuthInfoRequestMsg |
||||||
|
if err := Unmarshal(packet, &msg); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// Manually unpack the prompt/echo pairs.
|
||||||
|
rest := msg.Prompts |
||||||
|
var prompts []string |
||||||
|
var echos []bool |
||||||
|
for i := 0; i < int(msg.NumPrompts); i++ { |
||||||
|
prompt, r, ok := parseString(rest) |
||||||
|
if !ok || len(r) == 0 { |
||||||
|
return false, nil, errors.New("ssh: prompt format error") |
||||||
|
} |
||||||
|
prompts = append(prompts, string(prompt)) |
||||||
|
echos = append(echos, r[0] != 0) |
||||||
|
rest = r[1:] |
||||||
|
} |
||||||
|
|
||||||
|
if len(rest) != 0 { |
||||||
|
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") |
||||||
|
} |
||||||
|
|
||||||
|
answers, err := cb(msg.User, msg.Instruction, prompts, echos) |
||||||
|
if err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if len(answers) != len(prompts) { |
||||||
|
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") |
||||||
|
} |
||||||
|
responseLength := 1 + 4 |
||||||
|
for _, a := range answers { |
||||||
|
responseLength += stringLength(len(a)) |
||||||
|
} |
||||||
|
serialized := make([]byte, responseLength) |
||||||
|
p := serialized |
||||||
|
p[0] = msgUserAuthInfoResponse |
||||||
|
p = p[1:] |
||||||
|
p = marshalUint32(p, uint32(len(answers))) |
||||||
|
for _, a := range answers { |
||||||
|
p = marshalString(p, []byte(a)) |
||||||
|
} |
||||||
|
|
||||||
|
if err := c.writePacket(serialized); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,393 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
type keyboardInteractive map[string]string |
||||||
|
|
||||||
|
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { |
||||||
|
var answers []string |
||||||
|
for _, q := range questions { |
||||||
|
answers = append(answers, cr[q]) |
||||||
|
} |
||||||
|
return answers, nil |
||||||
|
} |
||||||
|
|
||||||
|
// reused internally by tests
|
||||||
|
var clientPassword = "tiger" |
||||||
|
|
||||||
|
// tryAuth runs a handshake with a given config against an SSH server
|
||||||
|
// with config serverConfig
|
||||||
|
func tryAuth(t *testing.T, config *ClientConfig) error { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
certChecker := CertChecker{ |
||||||
|
IsAuthority: func(k PublicKey) bool { |
||||||
|
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) |
||||||
|
}, |
||||||
|
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||||
|
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) |
||||||
|
}, |
||||||
|
IsRevoked: func(c *Certificate) bool { |
||||||
|
return c.Serial == 666 |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
serverConfig := &ServerConfig{ |
||||||
|
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { |
||||||
|
if conn.User() == "testuser" && string(pass) == clientPassword { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
return nil, errors.New("password auth failed") |
||||||
|
}, |
||||||
|
PublicKeyCallback: certChecker.Authenticate, |
||||||
|
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { |
||||||
|
ans, err := challenge("user", |
||||||
|
"instruction", |
||||||
|
[]string{"question1", "question2"}, |
||||||
|
[]bool{true, true}) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" |
||||||
|
if ok { |
||||||
|
challenge("user", "motd", nil, nil) |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
return nil, errors.New("keyboard-interactive failed") |
||||||
|
}, |
||||||
|
AuthLogCallback: func(conn ConnMetadata, method string, err error) { |
||||||
|
t.Logf("user %q, method %q: %v", conn.User(), method, err) |
||||||
|
}, |
||||||
|
} |
||||||
|
serverConfig.AddHostKey(testSigners["rsa"]) |
||||||
|
|
||||||
|
go newServer(c1, serverConfig) |
||||||
|
_, _, _, err = NewClientConn(c2, "", config) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func TestClientAuthPublicKey(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["rsa"]), |
||||||
|
}, |
||||||
|
} |
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %s", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthMethodPassword(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
Password(clientPassword), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %s", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthMethodFallback(t *testing.T) { |
||||||
|
var passwordCalled bool |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["rsa"]), |
||||||
|
PasswordCallback( |
||||||
|
func() (string, error) { |
||||||
|
passwordCalled = true |
||||||
|
return "WRONG", nil |
||||||
|
}), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
if passwordCalled { |
||||||
|
t.Errorf("password auth tried before public-key auth.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthMethodWrongPassword(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
Password("wrong"), |
||||||
|
PublicKeys(testSigners["rsa"]), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %s", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthMethodKeyboardInteractive(t *testing.T) { |
||||||
|
answers := keyboardInteractive(map[string]string{ |
||||||
|
"question1": "answer1", |
||||||
|
"question2": "answer2", |
||||||
|
}) |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
KeyboardInteractive(answers.Challenge), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %s", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { |
||||||
|
answers := keyboardInteractive(map[string]string{ |
||||||
|
"question1": "answer1", |
||||||
|
"question2": "WRONG", |
||||||
|
}) |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
KeyboardInteractive(answers.Challenge), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err == nil { |
||||||
|
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// the mock server will only authenticate ssh-rsa keys
|
||||||
|
func TestAuthMethodInvalidPublicKey(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["dsa"]), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
if err := tryAuth(t, config); err == nil { |
||||||
|
t.Fatalf("dsa private key should not have authenticated with rsa public key") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// the client should authenticate with the second key
|
||||||
|
func TestAuthMethodRSAandDSA(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["dsa"], testSigners["rsa"]), |
||||||
|
}, |
||||||
|
} |
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("client could not authenticate with rsa key: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestClientHMAC(t *testing.T) { |
||||||
|
for _, mac := range supportedMACs { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["rsa"]), |
||||||
|
}, |
||||||
|
Config: Config{ |
||||||
|
MACs: []string{mac}, |
||||||
|
}, |
||||||
|
} |
||||||
|
if err := tryAuth(t, config); err != nil { |
||||||
|
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// issue 4285.
|
||||||
|
func TestClientUnsupportedCipher(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(), |
||||||
|
}, |
||||||
|
Config: Config{ |
||||||
|
Ciphers: []string{"aes128-cbc"}, // not currently supported
|
||||||
|
}, |
||||||
|
} |
||||||
|
if err := tryAuth(t, config); err == nil { |
||||||
|
t.Errorf("expected no ciphers in common") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestClientUnsupportedKex(t *testing.T) { |
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(), |
||||||
|
}, |
||||||
|
Config: Config{ |
||||||
|
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
|
||||||
|
}, |
||||||
|
} |
||||||
|
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { |
||||||
|
t.Errorf("got %v, expected 'common algorithm'", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestClientLoginCert(t *testing.T) { |
||||||
|
cert := &Certificate{ |
||||||
|
Key: testPublicKeys["rsa"], |
||||||
|
ValidBefore: CertTimeInfinity, |
||||||
|
CertType: UserCert, |
||||||
|
} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewCertSigner: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
clientConfig := &ClientConfig{ |
||||||
|
User: "user", |
||||||
|
} |
||||||
|
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) |
||||||
|
|
||||||
|
t.Log("should succeed") |
||||||
|
if err := tryAuth(t, clientConfig); err != nil { |
||||||
|
t.Errorf("cert login failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("corrupted signature") |
||||||
|
cert.Signature.Blob[0]++ |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login passed with corrupted sig") |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("revoked") |
||||||
|
cert.Serial = 666 |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("revoked cert login succeeded") |
||||||
|
} |
||||||
|
cert.Serial = 1 |
||||||
|
|
||||||
|
t.Log("sign with wrong key") |
||||||
|
cert.SignCert(rand.Reader, testSigners["dsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login passed with non-authoritive key") |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("host cert") |
||||||
|
cert.CertType = HostCert |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login passed with wrong type") |
||||||
|
} |
||||||
|
cert.CertType = UserCert |
||||||
|
|
||||||
|
t.Log("principal specified") |
||||||
|
cert.ValidPrincipals = []string{"user"} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err != nil { |
||||||
|
t.Errorf("cert login failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("wrong principal specified") |
||||||
|
cert.ValidPrincipals = []string{"fred"} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login passed with wrong principal") |
||||||
|
} |
||||||
|
cert.ValidPrincipals = nil |
||||||
|
|
||||||
|
t.Log("added critical option") |
||||||
|
cert.CriticalOptions = map[string]string{"root-access": "yes"} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login passed with unrecognized critical option") |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("allowed source address") |
||||||
|
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err != nil { |
||||||
|
t.Errorf("cert login with source-address failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
t.Log("disallowed source address") |
||||||
|
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} |
||||||
|
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||||
|
if err := tryAuth(t, clientConfig); err == nil { |
||||||
|
t.Errorf("cert login with source-address succeeded") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func testPermissionsPassing(withPermissions bool, t *testing.T) { |
||||||
|
serverConfig := &ServerConfig{ |
||||||
|
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||||
|
if conn.User() == "nopermissions" { |
||||||
|
return nil, nil |
||||||
|
} else { |
||||||
|
return &Permissions{}, nil |
||||||
|
} |
||||||
|
}, |
||||||
|
} |
||||||
|
serverConfig.AddHostKey(testSigners["rsa"]) |
||||||
|
|
||||||
|
clientConfig := &ClientConfig{ |
||||||
|
Auth: []AuthMethod{ |
||||||
|
PublicKeys(testSigners["rsa"]), |
||||||
|
}, |
||||||
|
} |
||||||
|
if withPermissions { |
||||||
|
clientConfig.User = "permissions" |
||||||
|
} else { |
||||||
|
clientConfig.User = "nopermissions" |
||||||
|
} |
||||||
|
|
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
go NewClientConn(c2, "", clientConfig) |
||||||
|
serverConn, err := newServer(c1, serverConfig) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
if p := serverConn.Permissions; (p != nil) != withPermissions { |
||||||
|
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestPermissionsPassing(t *testing.T) { |
||||||
|
testPermissionsPassing(true, t) |
||||||
|
} |
||||||
|
|
||||||
|
func TestNoPermissionsPassing(t *testing.T) { |
||||||
|
testPermissionsPassing(false, t) |
||||||
|
} |
@ -0,0 +1,39 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"net" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func testClientVersion(t *testing.T, config *ClientConfig, expected string) { |
||||||
|
clientConn, serverConn := net.Pipe() |
||||||
|
defer clientConn.Close() |
||||||
|
receivedVersion := make(chan string, 1) |
||||||
|
go func() { |
||||||
|
version, err := readVersion(serverConn) |
||||||
|
if err != nil { |
||||||
|
receivedVersion <- "" |
||||||
|
} else { |
||||||
|
receivedVersion <- string(version) |
||||||
|
} |
||||||
|
serverConn.Close() |
||||||
|
}() |
||||||
|
NewClientConn(clientConn, "", config) |
||||||
|
actual := <-receivedVersion |
||||||
|
if actual != expected { |
||||||
|
t.Fatalf("got %s; want %s", actual, expected) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCustomClientVersion(t *testing.T) { |
||||||
|
version := "Test-Client-Version-0.0" |
||||||
|
testClientVersion(t, &ClientConfig{ClientVersion: version}, version) |
||||||
|
} |
||||||
|
|
||||||
|
func TestDefaultClientVersion(t *testing.T) { |
||||||
|
testClientVersion(t, &ClientConfig{}, packageVersion) |
||||||
|
} |
@ -0,0 +1,354 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto" |
||||||
|
"crypto/rand" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"sync" |
||||||
|
|
||||||
|
_ "crypto/sha1" |
||||||
|
_ "crypto/sha256" |
||||||
|
_ "crypto/sha512" |
||||||
|
) |
||||||
|
|
||||||
|
// These are string constants in the SSH protocol.
|
||||||
|
const ( |
||||||
|
compressionNone = "none" |
||||||
|
serviceUserAuth = "ssh-userauth" |
||||||
|
serviceSSH = "ssh-connection" |
||||||
|
) |
||||||
|
|
||||||
|
// supportedCiphers specifies the supported ciphers in preference order.
|
||||||
|
var supportedCiphers = []string{ |
||||||
|
"aes128-ctr", "aes192-ctr", "aes256-ctr", |
||||||
|
"aes128-gcm@openssh.com", |
||||||
|
"arcfour256", "arcfour128", |
||||||
|
} |
||||||
|
|
||||||
|
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
||||||
|
// preference order.
|
||||||
|
var supportedKexAlgos = []string{ |
||||||
|
kexAlgoCurve25519SHA256, |
||||||
|
// P384 and P521 are not constant-time yet, but since we don't
|
||||||
|
// reuse ephemeral keys, using them for ECDH should be OK.
|
||||||
|
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, |
||||||
|
kexAlgoDH14SHA1, kexAlgoDH1SHA1, |
||||||
|
} |
||||||
|
|
||||||
|
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
|
||||||
|
// of authenticating servers) in preference order.
|
||||||
|
var supportedHostKeyAlgos = []string{ |
||||||
|
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, |
||||||
|
CertAlgoECDSA384v01, CertAlgoECDSA521v01, |
||||||
|
|
||||||
|
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||||
|
KeyAlgoRSA, KeyAlgoDSA, |
||||||
|
} |
||||||
|
|
||||||
|
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
||||||
|
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
||||||
|
// because they have reached the end of their useful life.
|
||||||
|
var supportedMACs = []string{ |
||||||
|
"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", |
||||||
|
} |
||||||
|
|
||||||
|
var supportedCompressions = []string{compressionNone} |
||||||
|
|
||||||
|
// hashFuncs keeps the mapping of supported algorithms to their respective
|
||||||
|
// hashes needed for signature verification.
|
||||||
|
var hashFuncs = map[string]crypto.Hash{ |
||||||
|
KeyAlgoRSA: crypto.SHA1, |
||||||
|
KeyAlgoDSA: crypto.SHA1, |
||||||
|
KeyAlgoECDSA256: crypto.SHA256, |
||||||
|
KeyAlgoECDSA384: crypto.SHA384, |
||||||
|
KeyAlgoECDSA521: crypto.SHA512, |
||||||
|
CertAlgoRSAv01: crypto.SHA1, |
||||||
|
CertAlgoDSAv01: crypto.SHA1, |
||||||
|
CertAlgoECDSA256v01: crypto.SHA256, |
||||||
|
CertAlgoECDSA384v01: crypto.SHA384, |
||||||
|
CertAlgoECDSA521v01: crypto.SHA512, |
||||||
|
} |
||||||
|
|
||||||
|
// unexpectedMessageError results when the SSH message that we received didn't
|
||||||
|
// match what we wanted.
|
||||||
|
func unexpectedMessageError(expected, got uint8) error { |
||||||
|
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) |
||||||
|
} |
||||||
|
|
||||||
|
// parseError results from a malformed SSH message.
|
||||||
|
func parseError(tag uint8) error { |
||||||
|
return fmt.Errorf("ssh: parse error in message type %d", tag) |
||||||
|
} |
||||||
|
|
||||||
|
func findCommon(what string, client []string, server []string) (common string, err error) { |
||||||
|
for _, c := range client { |
||||||
|
for _, s := range server { |
||||||
|
if c == s { |
||||||
|
return c, nil |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) |
||||||
|
} |
||||||
|
|
||||||
|
type directionAlgorithms struct { |
||||||
|
Cipher string |
||||||
|
MAC string |
||||||
|
Compression string |
||||||
|
} |
||||||
|
|
||||||
|
type algorithms struct { |
||||||
|
kex string |
||||||
|
hostKey string |
||||||
|
w directionAlgorithms |
||||||
|
r directionAlgorithms |
||||||
|
} |
||||||
|
|
||||||
|
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { |
||||||
|
result := &algorithms{} |
||||||
|
|
||||||
|
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
return result, nil |
||||||
|
} |
||||||
|
|
||||||
|
// If rekeythreshold is too small, we can't make any progress sending
|
||||||
|
// stuff.
|
||||||
|
const minRekeyThreshold uint64 = 256 |
||||||
|
|
||||||
|
// Config contains configuration data common to both ServerConfig and
|
||||||
|
// ClientConfig.
|
||||||
|
type Config struct { |
||||||
|
// Rand provides the source of entropy for cryptographic
|
||||||
|
// primitives. If Rand is nil, the cryptographic random reader
|
||||||
|
// in package crypto/rand will be used.
|
||||||
|
Rand io.Reader |
||||||
|
|
||||||
|
// The maximum number of bytes sent or received after which a
|
||||||
|
// new key is negotiated. It must be at least 256. If
|
||||||
|
// unspecified, 1 gigabyte is used.
|
||||||
|
RekeyThreshold uint64 |
||||||
|
|
||||||
|
// The allowed key exchanges algorithms. If unspecified then a
|
||||||
|
// default set of algorithms is used.
|
||||||
|
KeyExchanges []string |
||||||
|
|
||||||
|
// The allowed cipher algorithms. If unspecified then a sensible
|
||||||
|
// default is used.
|
||||||
|
Ciphers []string |
||||||
|
|
||||||
|
// The allowed MAC algorithms. If unspecified then a sensible default
|
||||||
|
// is used.
|
||||||
|
MACs []string |
||||||
|
} |
||||||
|
|
||||||
|
// SetDefaults sets sensible values for unset fields in config. This is
|
||||||
|
// exported for testing: Configs passed to SSH functions are copied and have
|
||||||
|
// default values set automatically.
|
||||||
|
func (c *Config) SetDefaults() { |
||||||
|
if c.Rand == nil { |
||||||
|
c.Rand = rand.Reader |
||||||
|
} |
||||||
|
if c.Ciphers == nil { |
||||||
|
c.Ciphers = supportedCiphers |
||||||
|
} |
||||||
|
var ciphers []string |
||||||
|
for _, c := range c.Ciphers { |
||||||
|
if cipherModes[c] != nil { |
||||||
|
// reject the cipher if we have no cipherModes definition
|
||||||
|
ciphers = append(ciphers, c) |
||||||
|
} |
||||||
|
} |
||||||
|
c.Ciphers = ciphers |
||||||
|
|
||||||
|
if c.KeyExchanges == nil { |
||||||
|
c.KeyExchanges = supportedKexAlgos |
||||||
|
} |
||||||
|
|
||||||
|
if c.MACs == nil { |
||||||
|
c.MACs = supportedMACs |
||||||
|
} |
||||||
|
|
||||||
|
if c.RekeyThreshold == 0 { |
||||||
|
// RFC 4253, section 9 suggests rekeying after 1G.
|
||||||
|
c.RekeyThreshold = 1 << 30 |
||||||
|
} |
||||||
|
if c.RekeyThreshold < minRekeyThreshold { |
||||||
|
c.RekeyThreshold = minRekeyThreshold |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||||
|
// possession of a private key. See RFC 4252, section 7.
|
||||||
|
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { |
||||||
|
data := struct { |
||||||
|
Session []byte |
||||||
|
Type byte |
||||||
|
User string |
||||||
|
Service string |
||||||
|
Method string |
||||||
|
Sign bool |
||||||
|
Algo []byte |
||||||
|
PubKey []byte |
||||||
|
}{ |
||||||
|
sessionId, |
||||||
|
msgUserAuthRequest, |
||||||
|
req.User, |
||||||
|
req.Service, |
||||||
|
req.Method, |
||||||
|
true, |
||||||
|
algo, |
||||||
|
pubKey, |
||||||
|
} |
||||||
|
return Marshal(data) |
||||||
|
} |
||||||
|
|
||||||
|
func appendU16(buf []byte, n uint16) []byte { |
||||||
|
return append(buf, byte(n>>8), byte(n)) |
||||||
|
} |
||||||
|
|
||||||
|
func appendU32(buf []byte, n uint32) []byte { |
||||||
|
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||||
|
} |
||||||
|
|
||||||
|
func appendU64(buf []byte, n uint64) []byte { |
||||||
|
return append(buf, |
||||||
|
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), |
||||||
|
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||||
|
} |
||||||
|
|
||||||
|
func appendInt(buf []byte, n int) []byte { |
||||||
|
return appendU32(buf, uint32(n)) |
||||||
|
} |
||||||
|
|
||||||
|
func appendString(buf []byte, s string) []byte { |
||||||
|
buf = appendU32(buf, uint32(len(s))) |
||||||
|
buf = append(buf, s...) |
||||||
|
return buf |
||||||
|
} |
||||||
|
|
||||||
|
func appendBool(buf []byte, b bool) []byte { |
||||||
|
if b { |
||||||
|
return append(buf, 1) |
||||||
|
} |
||||||
|
return append(buf, 0) |
||||||
|
} |
||||||
|
|
||||||
|
// newCond is a helper to hide the fact that there is no usable zero
|
||||||
|
// value for sync.Cond.
|
||||||
|
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } |
||||||
|
|
||||||
|
// window represents the buffer available to clients
|
||||||
|
// wishing to write to a channel.
|
||||||
|
type window struct { |
||||||
|
*sync.Cond |
||||||
|
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||||
|
writeWaiters int |
||||||
|
closed bool |
||||||
|
} |
||||||
|
|
||||||
|
// add adds win to the amount of window available
|
||||||
|
// for consumers.
|
||||||
|
func (w *window) add(win uint32) bool { |
||||||
|
// a zero sized window adjust is a noop.
|
||||||
|
if win == 0 { |
||||||
|
return true |
||||||
|
} |
||||||
|
w.L.Lock() |
||||||
|
if w.win+win < win { |
||||||
|
w.L.Unlock() |
||||||
|
return false |
||||||
|
} |
||||||
|
w.win += win |
||||||
|
// It is unusual that multiple goroutines would be attempting to reserve
|
||||||
|
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||||
|
// that additional window is available.
|
||||||
|
w.Broadcast() |
||||||
|
w.L.Unlock() |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
// close sets the window to closed, so all reservations fail
|
||||||
|
// immediately.
|
||||||
|
func (w *window) close() { |
||||||
|
w.L.Lock() |
||||||
|
w.closed = true |
||||||
|
w.Broadcast() |
||||||
|
w.L.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// reserve reserves win from the available window capacity.
|
||||||
|
// If no capacity remains, reserve will block. reserve may
|
||||||
|
// return less than requested.
|
||||||
|
func (w *window) reserve(win uint32) (uint32, error) { |
||||||
|
var err error |
||||||
|
w.L.Lock() |
||||||
|
w.writeWaiters++ |
||||||
|
w.Broadcast() |
||||||
|
for w.win == 0 && !w.closed { |
||||||
|
w.Wait() |
||||||
|
} |
||||||
|
w.writeWaiters-- |
||||||
|
if w.win < win { |
||||||
|
win = w.win |
||||||
|
} |
||||||
|
w.win -= win |
||||||
|
if w.closed { |
||||||
|
err = io.EOF |
||||||
|
} |
||||||
|
w.L.Unlock() |
||||||
|
return win, err |
||||||
|
} |
||||||
|
|
||||||
|
// waitWriterBlocked waits until some goroutine is blocked for further
|
||||||
|
// writes. It is used in tests only.
|
||||||
|
func (w *window) waitWriterBlocked() { |
||||||
|
w.Cond.L.Lock() |
||||||
|
for w.writeWaiters == 0 { |
||||||
|
w.Cond.Wait() |
||||||
|
} |
||||||
|
w.Cond.L.Unlock() |
||||||
|
} |
@ -0,0 +1,144 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
) |
||||||
|
|
||||||
|
// OpenChannelError is returned if the other side rejects an
|
||||||
|
// OpenChannel request.
|
||||||
|
type OpenChannelError struct { |
||||||
|
Reason RejectionReason |
||||||
|
Message string |
||||||
|
} |
||||||
|
|
||||||
|
func (e *OpenChannelError) Error() string { |
||||||
|
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
||||||
|
} |
||||||
|
|
||||||
|
// ConnMetadata holds metadata for the connection.
|
||||||
|
type ConnMetadata interface { |
||||||
|
// User returns the user ID for this connection.
|
||||||
|
// It is empty if no authentication is used.
|
||||||
|
User() string |
||||||
|
|
||||||
|
// SessionID returns the sesson hash, also denoted by H.
|
||||||
|
SessionID() []byte |
||||||
|
|
||||||
|
// ClientVersion returns the client's version string as hashed
|
||||||
|
// into the session ID.
|
||||||
|
ClientVersion() []byte |
||||||
|
|
||||||
|
// ServerVersion returns the server's version string as hashed
|
||||||
|
// into the session ID.
|
||||||
|
ServerVersion() []byte |
||||||
|
|
||||||
|
// RemoteAddr returns the remote address for this connection.
|
||||||
|
RemoteAddr() net.Addr |
||||||
|
|
||||||
|
// LocalAddr returns the local address for this connection.
|
||||||
|
LocalAddr() net.Addr |
||||||
|
} |
||||||
|
|
||||||
|
// Conn represents an SSH connection for both server and client roles.
|
||||||
|
// Conn is the basis for implementing an application layer, such
|
||||||
|
// as ClientConn, which implements the traditional shell access for
|
||||||
|
// clients.
|
||||||
|
type Conn interface { |
||||||
|
ConnMetadata |
||||||
|
|
||||||
|
// SendRequest sends a global request, and returns the
|
||||||
|
// reply. If wantReply is true, it returns the response status
|
||||||
|
// and payload. See also RFC4254, section 4.
|
||||||
|
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) |
||||||
|
|
||||||
|
// OpenChannel tries to open an channel. If the request is
|
||||||
|
// rejected, it returns *OpenChannelError. On success it returns
|
||||||
|
// the SSH Channel and a Go channel for incoming, out-of-band
|
||||||
|
// requests. The Go channel must be serviced, or the
|
||||||
|
// connection will hang.
|
||||||
|
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) |
||||||
|
|
||||||
|
// Close closes the underlying network connection
|
||||||
|
Close() error |
||||||
|
|
||||||
|
// Wait blocks until the connection has shut down, and returns the
|
||||||
|
// error causing the shutdown.
|
||||||
|
Wait() error |
||||||
|
|
||||||
|
// TODO(hanwen): consider exposing:
|
||||||
|
// RequestKeyChange
|
||||||
|
// Disconnect
|
||||||
|
} |
||||||
|
|
||||||
|
// DiscardRequests consumes and rejects all requests from the
|
||||||
|
// passed-in channel.
|
||||||
|
func DiscardRequests(in <-chan *Request) { |
||||||
|
for req := range in { |
||||||
|
if req.WantReply { |
||||||
|
req.Reply(false, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// A connection represents an incoming connection.
|
||||||
|
type connection struct { |
||||||
|
transport *handshakeTransport |
||||||
|
sshConn |
||||||
|
|
||||||
|
// The connection protocol.
|
||||||
|
*mux |
||||||
|
} |
||||||
|
|
||||||
|
func (c *connection) Close() error { |
||||||
|
return c.sshConn.conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// sshconn provides net.Conn metadata, but disallows direct reads and
|
||||||
|
// writes.
|
||||||
|
type sshConn struct { |
||||||
|
conn net.Conn |
||||||
|
|
||||||
|
user string |
||||||
|
sessionID []byte |
||||||
|
clientVersion []byte |
||||||
|
serverVersion []byte |
||||||
|
} |
||||||
|
|
||||||
|
func dup(src []byte) []byte { |
||||||
|
dst := make([]byte, len(src)) |
||||||
|
copy(dst, src) |
||||||
|
return dst |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) User() string { |
||||||
|
return c.user |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) RemoteAddr() net.Addr { |
||||||
|
return c.conn.RemoteAddr() |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) Close() error { |
||||||
|
return c.conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) LocalAddr() net.Addr { |
||||||
|
return c.conn.LocalAddr() |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) SessionID() []byte { |
||||||
|
return dup(c.sessionID) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) ClientVersion() []byte { |
||||||
|
return dup(c.clientVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshConn) ServerVersion() []byte { |
||||||
|
return dup(c.serverVersion) |
||||||
|
} |
@ -0,0 +1,18 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/* |
||||||
|
Package ssh implements an SSH client and server. |
||||||
|
|
||||||
|
SSH is a transport security protocol, an authentication protocol and a |
||||||
|
family of application protocols. The most typical application level |
||||||
|
protocol is a remote shell and this is specifically implemented. However, |
||||||
|
the multiplexed nature of SSH is exposed to users that wish to support |
||||||
|
others. |
||||||
|
|
||||||
|
References: |
||||||
|
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
||||||
|
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||||
|
*/ |
||||||
|
package ssh // import "golang.org/x/crypto/ssh"
|
@ -0,0 +1,211 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"fmt" |
||||||
|
"io/ioutil" |
||||||
|
"log" |
||||||
|
"net" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh" |
||||||
|
"github.com/gogits/gogs/modules/crypto/ssh/terminal" |
||||||
|
) |
||||||
|
|
||||||
|
func ExampleNewServerConn() { |
||||||
|
// An SSH server is represented by a ServerConfig, which holds
|
||||||
|
// certificate details and handles authentication of ServerConns.
|
||||||
|
config := &ssh.ServerConfig{ |
||||||
|
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { |
||||||
|
// Should use constant-time compare (or better, salt+hash) in
|
||||||
|
// a production setting.
|
||||||
|
if c.User() == "testuser" && string(pass) == "tiger" { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
return nil, fmt.Errorf("password rejected for %q", c.User()) |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
privateBytes, err := ioutil.ReadFile("id_rsa") |
||||||
|
if err != nil { |
||||||
|
panic("Failed to load private key") |
||||||
|
} |
||||||
|
|
||||||
|
private, err := ssh.ParsePrivateKey(privateBytes) |
||||||
|
if err != nil { |
||||||
|
panic("Failed to parse private key") |
||||||
|
} |
||||||
|
|
||||||
|
config.AddHostKey(private) |
||||||
|
|
||||||
|
// Once a ServerConfig has been configured, connections can be
|
||||||
|
// accepted.
|
||||||
|
listener, err := net.Listen("tcp", "0.0.0.0:2022") |
||||||
|
if err != nil { |
||||||
|
panic("failed to listen for connection") |
||||||
|
} |
||||||
|
nConn, err := listener.Accept() |
||||||
|
if err != nil { |
||||||
|
panic("failed to accept incoming connection") |
||||||
|
} |
||||||
|
|
||||||
|
// Before use, a handshake must be performed on the incoming
|
||||||
|
// net.Conn.
|
||||||
|
_, chans, reqs, err := ssh.NewServerConn(nConn, config) |
||||||
|
if err != nil { |
||||||
|
panic("failed to handshake") |
||||||
|
} |
||||||
|
// The incoming Request channel must be serviced.
|
||||||
|
go ssh.DiscardRequests(reqs) |
||||||
|
|
||||||
|
// Service the incoming Channel channel.
|
||||||
|
for newChannel := range chans { |
||||||
|
// Channels have a type, depending on the application level
|
||||||
|
// protocol intended. In the case of a shell, the type is
|
||||||
|
// "session" and ServerShell may be used to present a simple
|
||||||
|
// terminal interface.
|
||||||
|
if newChannel.ChannelType() != "session" { |
||||||
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") |
||||||
|
continue |
||||||
|
} |
||||||
|
channel, requests, err := newChannel.Accept() |
||||||
|
if err != nil { |
||||||
|
panic("could not accept channel.") |
||||||
|
} |
||||||
|
|
||||||
|
// Sessions have out-of-band requests such as "shell",
|
||||||
|
// "pty-req" and "env". Here we handle only the
|
||||||
|
// "shell" request.
|
||||||
|
go func(in <-chan *ssh.Request) { |
||||||
|
for req := range in { |
||||||
|
ok := false |
||||||
|
switch req.Type { |
||||||
|
case "shell": |
||||||
|
ok = true |
||||||
|
if len(req.Payload) > 0 { |
||||||
|
// We don't accept any
|
||||||
|
// commands, only the
|
||||||
|
// default shell.
|
||||||
|
ok = false |
||||||
|
} |
||||||
|
} |
||||||
|
req.Reply(ok, nil) |
||||||
|
} |
||||||
|
}(requests) |
||||||
|
|
||||||
|
term := terminal.NewTerminal(channel, "> ") |
||||||
|
|
||||||
|
go func() { |
||||||
|
defer channel.Close() |
||||||
|
for { |
||||||
|
line, err := term.ReadLine() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
fmt.Println(line) |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func ExampleDial() { |
||||||
|
// An SSH client is represented with a ClientConn. Currently only
|
||||||
|
// the "password" authentication method is supported.
|
||||||
|
//
|
||||||
|
// To authenticate with the remote server you must pass at least one
|
||||||
|
// implementation of AuthMethod via the Auth field in ClientConfig.
|
||||||
|
config := &ssh.ClientConfig{ |
||||||
|
User: "username", |
||||||
|
Auth: []ssh.AuthMethod{ |
||||||
|
ssh.Password("yourpassword"), |
||||||
|
}, |
||||||
|
} |
||||||
|
client, err := ssh.Dial("tcp", "yourserver.com:22", config) |
||||||
|
if err != nil { |
||||||
|
panic("Failed to dial: " + err.Error()) |
||||||
|
} |
||||||
|
|
||||||
|
// Each ClientConn can support multiple interactive sessions,
|
||||||
|
// represented by a Session.
|
||||||
|
session, err := client.NewSession() |
||||||
|
if err != nil { |
||||||
|
panic("Failed to create session: " + err.Error()) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
// Once a Session is created, you can execute a single command on
|
||||||
|
// the remote side using the Run method.
|
||||||
|
var b bytes.Buffer |
||||||
|
session.Stdout = &b |
||||||
|
if err := session.Run("/usr/bin/whoami"); err != nil { |
||||||
|
panic("Failed to run: " + err.Error()) |
||||||
|
} |
||||||
|
fmt.Println(b.String()) |
||||||
|
} |
||||||
|
|
||||||
|
func ExampleClient_Listen() { |
||||||
|
config := &ssh.ClientConfig{ |
||||||
|
User: "username", |
||||||
|
Auth: []ssh.AuthMethod{ |
||||||
|
ssh.Password("password"), |
||||||
|
}, |
||||||
|
} |
||||||
|
// Dial your ssh server.
|
||||||
|
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||||
|
if err != nil { |
||||||
|
log.Fatalf("unable to connect: %s", err) |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
// Request the remote side to open port 8080 on all interfaces.
|
||||||
|
l, err := conn.Listen("tcp", "0.0.0.0:8080") |
||||||
|
if err != nil { |
||||||
|
log.Fatalf("unable to register tcp forward: %v", err) |
||||||
|
} |
||||||
|
defer l.Close() |
||||||
|
|
||||||
|
// Serve HTTP with your SSH server acting as a reverse proxy.
|
||||||
|
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||||
|
fmt.Fprintf(resp, "Hello world!\n") |
||||||
|
})) |
||||||
|
} |
||||||
|
|
||||||
|
func ExampleSession_RequestPty() { |
||||||
|
// Create client config
|
||||||
|
config := &ssh.ClientConfig{ |
||||||
|
User: "username", |
||||||
|
Auth: []ssh.AuthMethod{ |
||||||
|
ssh.Password("password"), |
||||||
|
}, |
||||||
|
} |
||||||
|
// Connect to ssh server
|
||||||
|
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||||
|
if err != nil { |
||||||
|
log.Fatalf("unable to connect: %s", err) |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
// Create a session
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
log.Fatalf("unable to create session: %s", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
// Set up terminal modes
|
||||||
|
modes := ssh.TerminalModes{ |
||||||
|
ssh.ECHO: 0, // disable echoing
|
||||||
|
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
||||||
|
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
||||||
|
} |
||||||
|
// Request pseudo terminal
|
||||||
|
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { |
||||||
|
log.Fatalf("request for pseudo terminal failed: %s", err) |
||||||
|
} |
||||||
|
// Start remote shell
|
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
log.Fatalf("failed to start shell: %s", err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,412 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"log" |
||||||
|
"net" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
// debugHandshake, if set, prints messages sent and received. Key
|
||||||
|
// exchange messages are printed as if DH were used, so the debug
|
||||||
|
// messages are wrong when using ECDH.
|
||||||
|
const debugHandshake = false |
||||||
|
|
||||||
|
// keyingTransport is a packet based transport that supports key
|
||||||
|
// changes. It need not be thread-safe. It should pass through
|
||||||
|
// msgNewKeys in both directions.
|
||||||
|
type keyingTransport interface { |
||||||
|
packetConn |
||||||
|
|
||||||
|
// prepareKeyChange sets up a key change. The key change for a
|
||||||
|
// direction will be effected if a msgNewKeys message is sent
|
||||||
|
// or received.
|
||||||
|
prepareKeyChange(*algorithms, *kexResult) error |
||||||
|
|
||||||
|
// getSessionID returns the session ID. prepareKeyChange must
|
||||||
|
// have been called once.
|
||||||
|
getSessionID() []byte |
||||||
|
} |
||||||
|
|
||||||
|
// rekeyingTransport is the interface of handshakeTransport that we
|
||||||
|
// (internally) expose to ClientConn and ServerConn.
|
||||||
|
type rekeyingTransport interface { |
||||||
|
packetConn |
||||||
|
|
||||||
|
// requestKeyChange asks the remote side to change keys. All
|
||||||
|
// writes are blocked until the key change succeeds, which is
|
||||||
|
// signaled by reading a msgNewKeys.
|
||||||
|
requestKeyChange() error |
||||||
|
|
||||||
|
// getSessionID returns the session ID. This is only valid
|
||||||
|
// after the first key change has completed.
|
||||||
|
getSessionID() []byte |
||||||
|
} |
||||||
|
|
||||||
|
// handshakeTransport implements rekeying on top of a keyingTransport
|
||||||
|
// and offers a thread-safe writePacket() interface.
|
||||||
|
type handshakeTransport struct { |
||||||
|
conn keyingTransport |
||||||
|
config *Config |
||||||
|
|
||||||
|
serverVersion []byte |
||||||
|
clientVersion []byte |
||||||
|
|
||||||
|
// hostKeys is non-empty if we are the server. In that case,
|
||||||
|
// it contains all host keys that can be used to sign the
|
||||||
|
// connection.
|
||||||
|
hostKeys []Signer |
||||||
|
|
||||||
|
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
||||||
|
// we accept these key types from the server as host key.
|
||||||
|
hostKeyAlgorithms []string |
||||||
|
|
||||||
|
// On read error, incoming is closed, and readError is set.
|
||||||
|
incoming chan []byte |
||||||
|
readError error |
||||||
|
|
||||||
|
// data for host key checking
|
||||||
|
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||||
|
dialAddress string |
||||||
|
remoteAddr net.Addr |
||||||
|
|
||||||
|
readSinceKex uint64 |
||||||
|
|
||||||
|
// Protects the writing side of the connection
|
||||||
|
mu sync.Mutex |
||||||
|
cond *sync.Cond |
||||||
|
sentInitPacket []byte |
||||||
|
sentInitMsg *kexInitMsg |
||||||
|
writtenSinceKex uint64 |
||||||
|
writeError error |
||||||
|
} |
||||||
|
|
||||||
|
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { |
||||||
|
t := &handshakeTransport{ |
||||||
|
conn: conn, |
||||||
|
serverVersion: serverVersion, |
||||||
|
clientVersion: clientVersion, |
||||||
|
incoming: make(chan []byte, 16), |
||||||
|
config: config, |
||||||
|
} |
||||||
|
t.cond = sync.NewCond(&t.mu) |
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { |
||||||
|
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||||
|
t.dialAddress = dialAddr |
||||||
|
t.remoteAddr = addr |
||||||
|
t.hostKeyCallback = config.HostKeyCallback |
||||||
|
if config.HostKeyAlgorithms != nil { |
||||||
|
t.hostKeyAlgorithms = config.HostKeyAlgorithms |
||||||
|
} else { |
||||||
|
t.hostKeyAlgorithms = supportedHostKeyAlgos |
||||||
|
} |
||||||
|
go t.readLoop() |
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { |
||||||
|
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||||
|
t.hostKeys = config.hostKeys |
||||||
|
go t.readLoop() |
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) getSessionID() []byte { |
||||||
|
return t.conn.getSessionID() |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) id() string { |
||||||
|
if len(t.hostKeys) > 0 { |
||||||
|
return "server" |
||||||
|
} |
||||||
|
return "client" |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) readPacket() ([]byte, error) { |
||||||
|
p, ok := <-t.incoming |
||||||
|
if !ok { |
||||||
|
return nil, t.readError |
||||||
|
} |
||||||
|
return p, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) readLoop() { |
||||||
|
for { |
||||||
|
p, err := t.readOnePacket() |
||||||
|
if err != nil { |
||||||
|
t.readError = err |
||||||
|
close(t.incoming) |
||||||
|
break |
||||||
|
} |
||||||
|
if p[0] == msgIgnore || p[0] == msgDebug { |
||||||
|
continue |
||||||
|
} |
||||||
|
t.incoming <- p |
||||||
|
} |
||||||
|
|
||||||
|
// If we can't read, declare the writing part dead too.
|
||||||
|
t.mu.Lock() |
||||||
|
defer t.mu.Unlock() |
||||||
|
if t.writeError == nil { |
||||||
|
t.writeError = t.readError |
||||||
|
} |
||||||
|
t.cond.Broadcast() |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) readOnePacket() ([]byte, error) { |
||||||
|
if t.readSinceKex > t.config.RekeyThreshold { |
||||||
|
if err := t.requestKeyChange(); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
p, err := t.conn.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
t.readSinceKex += uint64(len(p)) |
||||||
|
if debugHandshake { |
||||||
|
msg, err := decode(p) |
||||||
|
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) |
||||||
|
} |
||||||
|
if p[0] != msgKexInit { |
||||||
|
return p, nil |
||||||
|
} |
||||||
|
err = t.enterKeyExchange(p) |
||||||
|
|
||||||
|
t.mu.Lock() |
||||||
|
if err != nil { |
||||||
|
// drop connection
|
||||||
|
t.conn.Close() |
||||||
|
t.writeError = err |
||||||
|
} |
||||||
|
|
||||||
|
if debugHandshake { |
||||||
|
log.Printf("%s exited key exchange, err %v", t.id(), err) |
||||||
|
} |
||||||
|
|
||||||
|
// Unblock writers.
|
||||||
|
t.sentInitMsg = nil |
||||||
|
t.sentInitPacket = nil |
||||||
|
t.cond.Broadcast() |
||||||
|
t.writtenSinceKex = 0 |
||||||
|
t.mu.Unlock() |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
t.readSinceKex = 0 |
||||||
|
return []byte{msgNewKeys}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// sendKexInit sends a key change message, and returns the message
|
||||||
|
// that was sent. After initiating the key change, all writes will be
|
||||||
|
// blocked until the change is done, and a failed key change will
|
||||||
|
// close the underlying transport. This function is safe for
|
||||||
|
// concurrent use by multiple goroutines.
|
||||||
|
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { |
||||||
|
t.mu.Lock() |
||||||
|
defer t.mu.Unlock() |
||||||
|
return t.sendKexInitLocked() |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) requestKeyChange() error { |
||||||
|
_, _, err := t.sendKexInit() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// sendKexInitLocked sends a key change message. t.mu must be locked
|
||||||
|
// while this happens.
|
||||||
|
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { |
||||||
|
// kexInits may be sent either in response to the other side,
|
||||||
|
// or because our side wants to initiate a key change, so we
|
||||||
|
// may have already sent a kexInit. In that case, don't send a
|
||||||
|
// second kexInit.
|
||||||
|
if t.sentInitMsg != nil { |
||||||
|
return t.sentInitMsg, t.sentInitPacket, nil |
||||||
|
} |
||||||
|
msg := &kexInitMsg{ |
||||||
|
KexAlgos: t.config.KeyExchanges, |
||||||
|
CiphersClientServer: t.config.Ciphers, |
||||||
|
CiphersServerClient: t.config.Ciphers, |
||||||
|
MACsClientServer: t.config.MACs, |
||||||
|
MACsServerClient: t.config.MACs, |
||||||
|
CompressionClientServer: supportedCompressions, |
||||||
|
CompressionServerClient: supportedCompressions, |
||||||
|
} |
||||||
|
io.ReadFull(rand.Reader, msg.Cookie[:]) |
||||||
|
|
||||||
|
if len(t.hostKeys) > 0 { |
||||||
|
for _, k := range t.hostKeys { |
||||||
|
msg.ServerHostKeyAlgos = append( |
||||||
|
msg.ServerHostKeyAlgos, k.PublicKey().Type()) |
||||||
|
} |
||||||
|
} else { |
||||||
|
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms |
||||||
|
} |
||||||
|
packet := Marshal(msg) |
||||||
|
|
||||||
|
// writePacket destroys the contents, so save a copy.
|
||||||
|
packetCopy := make([]byte, len(packet)) |
||||||
|
copy(packetCopy, packet) |
||||||
|
|
||||||
|
if err := t.conn.writePacket(packetCopy); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
t.sentInitMsg = msg |
||||||
|
t.sentInitPacket = packet |
||||||
|
return msg, packet, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) writePacket(p []byte) error { |
||||||
|
t.mu.Lock() |
||||||
|
defer t.mu.Unlock() |
||||||
|
|
||||||
|
if t.writtenSinceKex > t.config.RekeyThreshold { |
||||||
|
t.sendKexInitLocked() |
||||||
|
} |
||||||
|
for t.sentInitMsg != nil && t.writeError == nil { |
||||||
|
t.cond.Wait() |
||||||
|
} |
||||||
|
if t.writeError != nil { |
||||||
|
return t.writeError |
||||||
|
} |
||||||
|
t.writtenSinceKex += uint64(len(p)) |
||||||
|
|
||||||
|
switch p[0] { |
||||||
|
case msgKexInit: |
||||||
|
return errors.New("ssh: only handshakeTransport can send kexInit") |
||||||
|
case msgNewKeys: |
||||||
|
return errors.New("ssh: only handshakeTransport can send newKeys") |
||||||
|
default: |
||||||
|
return t.conn.writePacket(p) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) Close() error { |
||||||
|
return t.conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// enterKeyExchange runs the key exchange.
|
||||||
|
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { |
||||||
|
if debugHandshake { |
||||||
|
log.Printf("%s entered key exchange", t.id()) |
||||||
|
} |
||||||
|
myInit, myInitPacket, err := t.sendKexInit() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
otherInit := &kexInitMsg{} |
||||||
|
if err := Unmarshal(otherInitPacket, otherInit); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
magics := handshakeMagics{ |
||||||
|
clientVersion: t.clientVersion, |
||||||
|
serverVersion: t.serverVersion, |
||||||
|
clientKexInit: otherInitPacket, |
||||||
|
serverKexInit: myInitPacket, |
||||||
|
} |
||||||
|
|
||||||
|
clientInit := otherInit |
||||||
|
serverInit := myInit |
||||||
|
if len(t.hostKeys) == 0 { |
||||||
|
clientInit = myInit |
||||||
|
serverInit = otherInit |
||||||
|
|
||||||
|
magics.clientKexInit = myInitPacket |
||||||
|
magics.serverKexInit = otherInitPacket |
||||||
|
} |
||||||
|
|
||||||
|
algs, err := findAgreedAlgorithms(clientInit, serverInit) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// We don't send FirstKexFollows, but we handle receiving it.
|
||||||
|
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { |
||||||
|
// other side sent a kex message for the wrong algorithm,
|
||||||
|
// which we have to ignore.
|
||||||
|
if _, err := t.conn.readPacket(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
kex, ok := kexAlgoMap[algs.kex] |
||||||
|
if !ok { |
||||||
|
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) |
||||||
|
} |
||||||
|
|
||||||
|
var result *kexResult |
||||||
|
if len(t.hostKeys) > 0 { |
||||||
|
result, err = t.server(kex, algs, &magics) |
||||||
|
} else { |
||||||
|
result, err = t.client(kex, algs, &magics) |
||||||
|
} |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
t.conn.prepareKeyChange(algs, result) |
||||||
|
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if packet, err := t.conn.readPacket(); err != nil { |
||||||
|
return err |
||||||
|
} else if packet[0] != msgNewKeys { |
||||||
|
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||||
|
var hostKey Signer |
||||||
|
for _, k := range t.hostKeys { |
||||||
|
if algs.hostKey == k.PublicKey().Type() { |
||||||
|
hostKey = k |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) |
||||||
|
return r, err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||||
|
result, err := kex.Client(t.conn, t.config.Rand, magics) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
hostKey, err := ParsePublicKey(result.HostKey) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if err := verifyHostKeySignature(hostKey, result); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if t.hostKeyCallback != nil { |
||||||
|
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return result, nil |
||||||
|
} |
@ -0,0 +1,415 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
"runtime" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
type testChecker struct { |
||||||
|
calls []string |
||||||
|
} |
||||||
|
|
||||||
|
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||||
|
if dialAddr == "bad" { |
||||||
|
return fmt.Errorf("dialAddr is bad") |
||||||
|
} |
||||||
|
|
||||||
|
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { |
||||||
|
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) |
||||||
|
} |
||||||
|
|
||||||
|
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||||
|
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||||
|
// a write.)
|
||||||
|
func netPipe() (net.Conn, net.Conn, error) { |
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
defer listener.Close() |
||||||
|
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c2, err := listener.Accept() |
||||||
|
if err != nil { |
||||||
|
c1.Close() |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return c1, c2, nil |
||||||
|
} |
||||||
|
|
||||||
|
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { |
||||||
|
a, b, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
trC := newTransport(a, rand.Reader, true) |
||||||
|
trS := newTransport(b, rand.Reader, false) |
||||||
|
clientConf.SetDefaults() |
||||||
|
|
||||||
|
v := []byte("version") |
||||||
|
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) |
||||||
|
|
||||||
|
serverConf := &ServerConfig{} |
||||||
|
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||||
|
serverConf.AddHostKey(testSigners["rsa"]) |
||||||
|
serverConf.SetDefaults() |
||||||
|
server = newServerTransport(trS, v, v, serverConf) |
||||||
|
|
||||||
|
return client, server, nil |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeBasic(t *testing.T) { |
||||||
|
if runtime.GOOS == "plan9" { |
||||||
|
t.Skip("see golang.org/issue/7237") |
||||||
|
} |
||||||
|
checker := &testChecker{} |
||||||
|
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("handshakePair: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
defer trC.Close() |
||||||
|
defer trS.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
// Client writes a bunch of stuff, and does a key
|
||||||
|
// change in the middle. This should not confuse the
|
||||||
|
// handshake in progress
|
||||||
|
for i := 0; i < 10; i++ { |
||||||
|
p := []byte{msgRequestSuccess, byte(i)} |
||||||
|
if err := trC.writePacket(p); err != nil { |
||||||
|
t.Fatalf("sendPacket: %v", err) |
||||||
|
} |
||||||
|
if i == 5 { |
||||||
|
// halfway through, we request a key change.
|
||||||
|
_, _, err := trC.sendKexInit() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("sendKexInit: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
trC.Close() |
||||||
|
}() |
||||||
|
|
||||||
|
// Server checks that client messages come in cleanly
|
||||||
|
i := 0 |
||||||
|
for { |
||||||
|
p, err := trS.readPacket() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
if p[0] == msgNewKeys { |
||||||
|
continue |
||||||
|
} |
||||||
|
want := []byte{msgRequestSuccess, byte(i)} |
||||||
|
if bytes.Compare(p, want) != 0 { |
||||||
|
t.Errorf("message %d: got %q, want %q", i, p, want) |
||||||
|
} |
||||||
|
i++ |
||||||
|
} |
||||||
|
if i != 10 { |
||||||
|
t.Errorf("received %d messages, want 10.", i) |
||||||
|
} |
||||||
|
|
||||||
|
// If all went well, we registered exactly 1 key change.
|
||||||
|
if len(checker.calls) != 1 { |
||||||
|
t.Fatalf("got %d host key checks, want 1", len(checker.calls)) |
||||||
|
} |
||||||
|
|
||||||
|
pub := testSigners["ecdsa"].PublicKey() |
||||||
|
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) |
||||||
|
if want != checker.calls[0] { |
||||||
|
t.Errorf("got %q want %q for host key check", checker.calls[0], want) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeError(t *testing.T) { |
||||||
|
checker := &testChecker{} |
||||||
|
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("handshakePair: %v", err) |
||||||
|
} |
||||||
|
defer trC.Close() |
||||||
|
defer trS.Close() |
||||||
|
|
||||||
|
// send a packet
|
||||||
|
packet := []byte{msgRequestSuccess, 42} |
||||||
|
if err := trC.writePacket(packet); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Now request a key change.
|
||||||
|
_, _, err = trC.sendKexInit() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("sendKexInit: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// the key change will fail, and afterwards we can't write.
|
||||||
|
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { |
||||||
|
t.Errorf("writePacket after botched rekey succeeded.") |
||||||
|
} |
||||||
|
|
||||||
|
readback, err := trS.readPacket() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("server closed too soon: %v", err) |
||||||
|
} |
||||||
|
if bytes.Compare(readback, packet) != 0 { |
||||||
|
t.Errorf("got %q want %q", readback, packet) |
||||||
|
} |
||||||
|
readback, err = trS.readPacket() |
||||||
|
if err == nil { |
||||||
|
t.Errorf("got a message %q after failed key change", readback) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeTwice(t *testing.T) { |
||||||
|
checker := &testChecker{} |
||||||
|
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("handshakePair: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
defer trC.Close() |
||||||
|
defer trS.Close() |
||||||
|
|
||||||
|
// send a packet
|
||||||
|
packet := make([]byte, 5) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
if err := trC.writePacket(packet); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Now request a key change.
|
||||||
|
_, _, err = trC.sendKexInit() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("sendKexInit: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Send another packet. Use a fresh one, since writePacket destroys.
|
||||||
|
packet = make([]byte, 5) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
if err := trC.writePacket(packet); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 2nd key change.
|
||||||
|
_, _, err = trC.sendKexInit() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("sendKexInit: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
packet = make([]byte, 5) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
if err := trC.writePacket(packet); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
packet = make([]byte, 5) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
for i := 0; i < 5; i++ { |
||||||
|
msg, err := trS.readPacket() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("server closed too soon: %v", err) |
||||||
|
} |
||||||
|
if msg[0] == msgNewKeys { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if bytes.Compare(msg, packet) != 0 { |
||||||
|
t.Errorf("packet %d: got %q want %q", i, msg, packet) |
||||||
|
} |
||||||
|
} |
||||||
|
if len(checker.calls) != 2 { |
||||||
|
t.Errorf("got %d key changes, want 2", len(checker.calls)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeAutoRekeyWrite(t *testing.T) { |
||||||
|
checker := &testChecker{} |
||||||
|
clientConf := &ClientConfig{HostKeyCallback: checker.Check} |
||||||
|
clientConf.RekeyThreshold = 500 |
||||||
|
trC, trS, err := handshakePair(clientConf, "addr") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("handshakePair: %v", err) |
||||||
|
} |
||||||
|
defer trC.Close() |
||||||
|
defer trS.Close() |
||||||
|
|
||||||
|
for i := 0; i < 5; i++ { |
||||||
|
packet := make([]byte, 251) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
if err := trC.writePacket(packet); err != nil { |
||||||
|
t.Errorf("writePacket: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
j := 0 |
||||||
|
for ; j < 5; j++ { |
||||||
|
_, err := trS.readPacket() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if j != 5 { |
||||||
|
t.Errorf("got %d, want 5 messages", j) |
||||||
|
} |
||||||
|
|
||||||
|
if len(checker.calls) != 2 { |
||||||
|
t.Errorf("got %d key changes, wanted 2", len(checker.calls)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type syncChecker struct { |
||||||
|
called chan int |
||||||
|
} |
||||||
|
|
||||||
|
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||||
|
t.called <- 1 |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeAutoRekeyRead(t *testing.T) { |
||||||
|
sync := &syncChecker{make(chan int, 2)} |
||||||
|
clientConf := &ClientConfig{ |
||||||
|
HostKeyCallback: sync.Check, |
||||||
|
} |
||||||
|
clientConf.RekeyThreshold = 500 |
||||||
|
|
||||||
|
trC, trS, err := handshakePair(clientConf, "addr") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("handshakePair: %v", err) |
||||||
|
} |
||||||
|
defer trC.Close() |
||||||
|
defer trS.Close() |
||||||
|
|
||||||
|
packet := make([]byte, 501) |
||||||
|
packet[0] = msgRequestSuccess |
||||||
|
if err := trS.writePacket(packet); err != nil { |
||||||
|
t.Fatalf("writePacket: %v", err) |
||||||
|
} |
||||||
|
// While we read out the packet, a key change will be
|
||||||
|
// initiated.
|
||||||
|
if _, err := trC.readPacket(); err != nil { |
||||||
|
t.Fatalf("readPacket(client): %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
<-sync.called |
||||||
|
} |
||||||
|
|
||||||
|
// errorKeyingTransport generates errors after a given number of
|
||||||
|
// read/write operations.
|
||||||
|
type errorKeyingTransport struct { |
||||||
|
packetConn |
||||||
|
readLeft, writeLeft int |
||||||
|
} |
||||||
|
|
||||||
|
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
func (n *errorKeyingTransport) getSessionID() []byte { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (n *errorKeyingTransport) writePacket(packet []byte) error { |
||||||
|
if n.writeLeft == 0 { |
||||||
|
n.Close() |
||||||
|
return errors.New("barf") |
||||||
|
} |
||||||
|
|
||||||
|
n.writeLeft-- |
||||||
|
return n.packetConn.writePacket(packet) |
||||||
|
} |
||||||
|
|
||||||
|
func (n *errorKeyingTransport) readPacket() ([]byte, error) { |
||||||
|
if n.readLeft == 0 { |
||||||
|
n.Close() |
||||||
|
return nil, errors.New("barf") |
||||||
|
} |
||||||
|
|
||||||
|
n.readLeft-- |
||||||
|
return n.packetConn.readPacket() |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeErrorHandlingRead(t *testing.T) { |
||||||
|
for i := 0; i < 20; i++ { |
||||||
|
testHandshakeErrorHandlingN(t, i, -1) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHandshakeErrorHandlingWrite(t *testing.T) { |
||||||
|
for i := 0; i < 20; i++ { |
||||||
|
testHandshakeErrorHandlingN(t, -1, i) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
|
||||||
|
// handshakeTransport deadlocks, the go runtime will detect it and
|
||||||
|
// panic.
|
||||||
|
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { |
||||||
|
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) |
||||||
|
|
||||||
|
a, b := memPipe() |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
|
||||||
|
key := testSigners["ecdsa"] |
||||||
|
serverConf := Config{RekeyThreshold: minRekeyThreshold} |
||||||
|
serverConf.SetDefaults() |
||||||
|
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) |
||||||
|
serverConn.hostKeys = []Signer{key} |
||||||
|
go serverConn.readLoop() |
||||||
|
|
||||||
|
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} |
||||||
|
clientConf.SetDefaults() |
||||||
|
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) |
||||||
|
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} |
||||||
|
go clientConn.readLoop() |
||||||
|
|
||||||
|
var wg sync.WaitGroup |
||||||
|
wg.Add(4) |
||||||
|
|
||||||
|
for _, hs := range []packetConn{serverConn, clientConn} { |
||||||
|
go func(c packetConn) { |
||||||
|
for { |
||||||
|
err := c.writePacket(msg) |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
wg.Done() |
||||||
|
}(hs) |
||||||
|
go func(c packetConn) { |
||||||
|
for { |
||||||
|
_, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
wg.Done() |
||||||
|
}(hs) |
||||||
|
} |
||||||
|
|
||||||
|
wg.Wait() |
||||||
|
} |
@ -0,0 +1,526 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto" |
||||||
|
"crypto/ecdsa" |
||||||
|
"crypto/elliptic" |
||||||
|
"crypto/subtle" |
||||||
|
"crypto/rand" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"math/big" |
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" |
||||||
|
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" |
||||||
|
kexAlgoECDH256 = "ecdh-sha2-nistp256" |
||||||
|
kexAlgoECDH384 = "ecdh-sha2-nistp384" |
||||||
|
kexAlgoECDH521 = "ecdh-sha2-nistp521" |
||||||
|
kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" |
||||||
|
) |
||||||
|
|
||||||
|
// kexResult captures the outcome of a key exchange.
|
||||||
|
type kexResult struct { |
||||||
|
// Session hash. See also RFC 4253, section 8.
|
||||||
|
H []byte |
||||||
|
|
||||||
|
// Shared secret. See also RFC 4253, section 8.
|
||||||
|
K []byte |
||||||
|
|
||||||
|
// Host key as hashed into H.
|
||||||
|
HostKey []byte |
||||||
|
|
||||||
|
// Signature of H.
|
||||||
|
Signature []byte |
||||||
|
|
||||||
|
// A cryptographic hash function that matches the security
|
||||||
|
// level of the key exchange algorithm. It is used for
|
||||||
|
// calculating H, and for deriving keys from H and K.
|
||||||
|
Hash crypto.Hash |
||||||
|
|
||||||
|
// The session ID, which is the first H computed. This is used
|
||||||
|
// to signal data inside transport.
|
||||||
|
SessionID []byte |
||||||
|
} |
||||||
|
|
||||||
|
// handshakeMagics contains data that is always included in the
|
||||||
|
// session hash.
|
||||||
|
type handshakeMagics struct { |
||||||
|
clientVersion, serverVersion []byte |
||||||
|
clientKexInit, serverKexInit []byte |
||||||
|
} |
||||||
|
|
||||||
|
func (m *handshakeMagics) write(w io.Writer) { |
||||||
|
writeString(w, m.clientVersion) |
||||||
|
writeString(w, m.serverVersion) |
||||||
|
writeString(w, m.clientKexInit) |
||||||
|
writeString(w, m.serverKexInit) |
||||||
|
} |
||||||
|
|
||||||
|
// kexAlgorithm abstracts different key exchange algorithms.
|
||||||
|
type kexAlgorithm interface { |
||||||
|
// Server runs server-side key agreement, signing the result
|
||||||
|
// with a hostkey.
|
||||||
|
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) |
||||||
|
|
||||||
|
// Client runs the client-side key agreement. Caller is
|
||||||
|
// responsible for verifying the host key signature.
|
||||||
|
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) |
||||||
|
} |
||||||
|
|
||||||
|
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||||
|
type dhGroup struct { |
||||||
|
g, p *big.Int |
||||||
|
} |
||||||
|
|
||||||
|
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { |
||||||
|
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 { |
||||||
|
return nil, errors.New("ssh: DH parameter out of bounds") |
||||||
|
} |
||||||
|
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||||
|
hashFunc := crypto.SHA1 |
||||||
|
|
||||||
|
x, err := rand.Int(randSource, group.p) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
X := new(big.Int).Exp(group.g, x, group.p) |
||||||
|
kexDHInit := kexDHInitMsg{ |
||||||
|
X: X, |
||||||
|
} |
||||||
|
if err := c.writePacket(Marshal(&kexDHInit)); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var kexDHReply kexDHReplyMsg |
||||||
|
if err = Unmarshal(packet, &kexDHReply); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
kInt, err := group.diffieHellman(kexDHReply.Y, x) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
h := hashFunc.New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, kexDHReply.HostKey) |
||||||
|
writeInt(h, X) |
||||||
|
writeInt(h, kexDHReply.Y) |
||||||
|
K := make([]byte, intLength(kInt)) |
||||||
|
marshalInt(K, kInt) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
return &kexResult{ |
||||||
|
H: h.Sum(nil), |
||||||
|
K: K, |
||||||
|
HostKey: kexDHReply.HostKey, |
||||||
|
Signature: kexDHReply.Signature, |
||||||
|
Hash: crypto.SHA1, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||||
|
hashFunc := crypto.SHA1 |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
var kexDHInit kexDHInitMsg |
||||||
|
if err = Unmarshal(packet, &kexDHInit); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
y, err := rand.Int(randSource, group.p) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
Y := new(big.Int).Exp(group.g, y, group.p) |
||||||
|
kInt, err := group.diffieHellman(kexDHInit.X, y) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
hostKeyBytes := priv.PublicKey().Marshal() |
||||||
|
|
||||||
|
h := hashFunc.New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, hostKeyBytes) |
||||||
|
writeInt(h, kexDHInit.X) |
||||||
|
writeInt(h, Y) |
||||||
|
|
||||||
|
K := make([]byte, intLength(kInt)) |
||||||
|
marshalInt(K, kInt) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
H := h.Sum(nil) |
||||||
|
|
||||||
|
// H is already a hash, but the hostkey signing will apply its
|
||||||
|
// own key-specific hash algorithm.
|
||||||
|
sig, err := signAndMarshal(priv, randSource, H) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
kexDHReply := kexDHReplyMsg{ |
||||||
|
HostKey: hostKeyBytes, |
||||||
|
Y: Y, |
||||||
|
Signature: sig, |
||||||
|
} |
||||||
|
packet = Marshal(&kexDHReply) |
||||||
|
|
||||||
|
err = c.writePacket(packet) |
||||||
|
return &kexResult{ |
||||||
|
H: H, |
||||||
|
K: K, |
||||||
|
HostKey: hostKeyBytes, |
||||||
|
Signature: sig, |
||||||
|
Hash: crypto.SHA1, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
||||||
|
// described in RFC 5656, section 4.
|
||||||
|
type ecdh struct { |
||||||
|
curve elliptic.Curve |
||||||
|
} |
||||||
|
|
||||||
|
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||||
|
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
kexInit := kexECDHInitMsg{ |
||||||
|
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), |
||||||
|
} |
||||||
|
|
||||||
|
serialized := Marshal(&kexInit) |
||||||
|
if err := c.writePacket(serialized); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var reply kexECDHReplyMsg |
||||||
|
if err = Unmarshal(packet, &reply); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// generate shared secret
|
||||||
|
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) |
||||||
|
|
||||||
|
h := ecHash(kex.curve).New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, reply.HostKey) |
||||||
|
writeString(h, kexInit.ClientPubKey) |
||||||
|
writeString(h, reply.EphemeralPubKey) |
||||||
|
K := make([]byte, intLength(secret)) |
||||||
|
marshalInt(K, secret) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
return &kexResult{ |
||||||
|
H: h.Sum(nil), |
||||||
|
K: K, |
||||||
|
HostKey: reply.HostKey, |
||||||
|
Signature: reply.Signature, |
||||||
|
Hash: ecHash(kex.curve), |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// unmarshalECKey parses and checks an EC key.
|
||||||
|
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { |
||||||
|
x, y = elliptic.Unmarshal(curve, pubkey) |
||||||
|
if x == nil { |
||||||
|
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") |
||||||
|
} |
||||||
|
if !validateECPublicKey(curve, x, y) { |
||||||
|
return nil, nil, errors.New("ssh: public key not on curve") |
||||||
|
} |
||||||
|
return x, y, nil |
||||||
|
} |
||||||
|
|
||||||
|
// validateECPublicKey checks that the point is a valid public key for
|
||||||
|
// the given curve. See [SEC1], 3.2.2
|
||||||
|
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { |
||||||
|
if x.Sign() == 0 && y.Sign() == 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
if x.Cmp(curve.Params().P) >= 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
if y.Cmp(curve.Params().P) >= 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
if !curve.IsOnCurve(x, y) { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// We don't check if N * PubKey == 0, since
|
||||||
|
//
|
||||||
|
// - the NIST curves have cofactor = 1, so this is implicit.
|
||||||
|
// (We don't foresee an implementation that supports non NIST
|
||||||
|
// curves)
|
||||||
|
//
|
||||||
|
// - for ephemeral keys, we don't need to worry about small
|
||||||
|
// subgroup attacks.
|
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var kexECDHInit kexECDHInitMsg |
||||||
|
if err = Unmarshal(packet, &kexECDHInit); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// We could cache this key across multiple users/multiple
|
||||||
|
// connection attempts, but the benefit is small. OpenSSH
|
||||||
|
// generates a new key for each incoming connection.
|
||||||
|
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
hostKeyBytes := priv.PublicKey().Marshal() |
||||||
|
|
||||||
|
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) |
||||||
|
|
||||||
|
// generate shared secret
|
||||||
|
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) |
||||||
|
|
||||||
|
h := ecHash(kex.curve).New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, hostKeyBytes) |
||||||
|
writeString(h, kexECDHInit.ClientPubKey) |
||||||
|
writeString(h, serializedEphKey) |
||||||
|
|
||||||
|
K := make([]byte, intLength(secret)) |
||||||
|
marshalInt(K, secret) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
H := h.Sum(nil) |
||||||
|
|
||||||
|
// H is already a hash, but the hostkey signing will apply its
|
||||||
|
// own key-specific hash algorithm.
|
||||||
|
sig, err := signAndMarshal(priv, rand, H) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
reply := kexECDHReplyMsg{ |
||||||
|
EphemeralPubKey: serializedEphKey, |
||||||
|
HostKey: hostKeyBytes, |
||||||
|
Signature: sig, |
||||||
|
} |
||||||
|
|
||||||
|
serialized := Marshal(&reply) |
||||||
|
if err := c.writePacket(serialized); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &kexResult{ |
||||||
|
H: H, |
||||||
|
K: K, |
||||||
|
HostKey: reply.HostKey, |
||||||
|
Signature: sig, |
||||||
|
Hash: ecHash(kex.curve), |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
var kexAlgoMap = map[string]kexAlgorithm{} |
||||||
|
|
||||||
|
func init() { |
||||||
|
// This is the group called diffie-hellman-group1-sha1 in RFC
|
||||||
|
// 4253 and Oakley Group 2 in RFC 2409.
|
||||||
|
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) |
||||||
|
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ |
||||||
|
g: new(big.Int).SetInt64(2), |
||||||
|
p: p, |
||||||
|
} |
||||||
|
|
||||||
|
// This is the group called diffie-hellman-group14-sha1 in RFC
|
||||||
|
// 4253 and Oakley Group 14 in RFC 3526.
|
||||||
|
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) |
||||||
|
|
||||||
|
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ |
||||||
|
g: new(big.Int).SetInt64(2), |
||||||
|
p: p, |
||||||
|
} |
||||||
|
|
||||||
|
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} |
||||||
|
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} |
||||||
|
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} |
||||||
|
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} |
||||||
|
} |
||||||
|
|
||||||
|
// curve25519sha256 implements the curve25519-sha256@libssh.org key
|
||||||
|
// agreement protocol, as described in
|
||||||
|
// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
|
||||||
|
type curve25519sha256 struct{} |
||||||
|
|
||||||
|
type curve25519KeyPair struct { |
||||||
|
priv [32]byte |
||||||
|
pub [32]byte |
||||||
|
} |
||||||
|
|
||||||
|
func (kp *curve25519KeyPair) generate(rand io.Reader) error { |
||||||
|
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
curve25519.ScalarBaseMult(&kp.pub, &kp.priv) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
||||||
|
// convenient to compare against in order to reject curve25519 points with the
|
||||||
|
// wrong order.
|
||||||
|
var curve25519Zeros [32]byte |
||||||
|
|
||||||
|
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||||
|
var kp curve25519KeyPair |
||||||
|
if err := kp.generate(rand); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var reply kexECDHReplyMsg |
||||||
|
if err = Unmarshal(packet, &reply); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(reply.EphemeralPubKey) != 32 { |
||||||
|
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||||
|
} |
||||||
|
|
||||||
|
var servPub, secret [32]byte |
||||||
|
copy(servPub[:], reply.EphemeralPubKey) |
||||||
|
curve25519.ScalarMult(&secret, &kp.priv, &servPub) |
||||||
|
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||||
|
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||||
|
} |
||||||
|
|
||||||
|
h := crypto.SHA256.New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, reply.HostKey) |
||||||
|
writeString(h, kp.pub[:]) |
||||||
|
writeString(h, reply.EphemeralPubKey) |
||||||
|
|
||||||
|
kInt := new(big.Int).SetBytes(secret[:]) |
||||||
|
K := make([]byte, intLength(kInt)) |
||||||
|
marshalInt(K, kInt) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
return &kexResult{ |
||||||
|
H: h.Sum(nil), |
||||||
|
K: K, |
||||||
|
HostKey: reply.HostKey, |
||||||
|
Signature: reply.Signature, |
||||||
|
Hash: crypto.SHA256, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||||
|
packet, err := c.readPacket() |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
var kexInit kexECDHInitMsg |
||||||
|
if err = Unmarshal(packet, &kexInit); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if len(kexInit.ClientPubKey) != 32 { |
||||||
|
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||||
|
} |
||||||
|
|
||||||
|
var kp curve25519KeyPair |
||||||
|
if err := kp.generate(rand); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var clientPub, secret [32]byte |
||||||
|
copy(clientPub[:], kexInit.ClientPubKey) |
||||||
|
curve25519.ScalarMult(&secret, &kp.priv, &clientPub) |
||||||
|
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||||
|
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||||
|
} |
||||||
|
|
||||||
|
hostKeyBytes := priv.PublicKey().Marshal() |
||||||
|
|
||||||
|
h := crypto.SHA256.New() |
||||||
|
magics.write(h) |
||||||
|
writeString(h, hostKeyBytes) |
||||||
|
writeString(h, kexInit.ClientPubKey) |
||||||
|
writeString(h, kp.pub[:]) |
||||||
|
|
||||||
|
kInt := new(big.Int).SetBytes(secret[:]) |
||||||
|
K := make([]byte, intLength(kInt)) |
||||||
|
marshalInt(K, kInt) |
||||||
|
h.Write(K) |
||||||
|
|
||||||
|
H := h.Sum(nil) |
||||||
|
|
||||||
|
sig, err := signAndMarshal(priv, rand, H) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
reply := kexECDHReplyMsg{ |
||||||
|
EphemeralPubKey: kp.pub[:], |
||||||
|
HostKey: hostKeyBytes, |
||||||
|
Signature: sig, |
||||||
|
} |
||||||
|
if err := c.writePacket(Marshal(&reply)); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &kexResult{ |
||||||
|
H: H, |
||||||
|
K: K, |
||||||
|
HostKey: hostKeyBytes, |
||||||
|
Signature: sig, |
||||||
|
Hash: crypto.SHA256, |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,50 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
// Key exchange tests.
|
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"reflect" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestKexes(t *testing.T) { |
||||||
|
type kexResultErr struct { |
||||||
|
result *kexResult |
||||||
|
err error |
||||||
|
} |
||||||
|
|
||||||
|
for name, kex := range kexAlgoMap { |
||||||
|
a, b := memPipe() |
||||||
|
|
||||||
|
s := make(chan kexResultErr, 1) |
||||||
|
c := make(chan kexResultErr, 1) |
||||||
|
var magics handshakeMagics |
||||||
|
go func() { |
||||||
|
r, e := kex.Client(a, rand.Reader, &magics) |
||||||
|
a.Close() |
||||||
|
c <- kexResultErr{r, e} |
||||||
|
}() |
||||||
|
go func() { |
||||||
|
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) |
||||||
|
b.Close() |
||||||
|
s <- kexResultErr{r, e} |
||||||
|
}() |
||||||
|
|
||||||
|
clientRes := <-c |
||||||
|
serverRes := <-s |
||||||
|
if clientRes.err != nil { |
||||||
|
t.Errorf("client: %v", clientRes.err) |
||||||
|
} |
||||||
|
if serverRes.err != nil { |
||||||
|
t.Errorf("server: %v", serverRes.err) |
||||||
|
} |
||||||
|
if !reflect.DeepEqual(clientRes.result, serverRes.result) { |
||||||
|
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,628 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto" |
||||||
|
"crypto/dsa" |
||||||
|
"crypto/ecdsa" |
||||||
|
"crypto/elliptic" |
||||||
|
"crypto/rsa" |
||||||
|
"crypto/x509" |
||||||
|
"encoding/asn1" |
||||||
|
"encoding/base64" |
||||||
|
"encoding/pem" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"math/big" |
||||||
|
) |
||||||
|
|
||||||
|
// These constants represent the algorithm names for key types supported by this
|
||||||
|
// package.
|
||||||
|
const ( |
||||||
|
KeyAlgoRSA = "ssh-rsa" |
||||||
|
KeyAlgoDSA = "ssh-dss" |
||||||
|
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" |
||||||
|
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" |
||||||
|
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" |
||||||
|
) |
||||||
|
|
||||||
|
// parsePubKey parses a public key of the given algorithm.
|
||||||
|
// Use ParsePublicKey for keys with prepended algorithm.
|
||||||
|
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { |
||||||
|
switch algo { |
||||||
|
case KeyAlgoRSA: |
||||||
|
return parseRSA(in) |
||||||
|
case KeyAlgoDSA: |
||||||
|
return parseDSA(in) |
||||||
|
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: |
||||||
|
return parseECDSA(in) |
||||||
|
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||||
|
cert, err := parseCert(in, certToPrivAlgo(algo)) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
return cert, nil, nil |
||||||
|
} |
||||||
|
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
|
||||||
|
// (see sshd(8) manual page) once the options and key type fields have been
|
||||||
|
// removed.
|
||||||
|
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { |
||||||
|
in = bytes.TrimSpace(in) |
||||||
|
|
||||||
|
i := bytes.IndexAny(in, " \t") |
||||||
|
if i == -1 { |
||||||
|
i = len(in) |
||||||
|
} |
||||||
|
base64Key := in[:i] |
||||||
|
|
||||||
|
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) |
||||||
|
n, err := base64.StdEncoding.Decode(key, base64Key) |
||||||
|
if err != nil { |
||||||
|
return nil, "", err |
||||||
|
} |
||||||
|
key = key[:n] |
||||||
|
out, err = ParsePublicKey(key) |
||||||
|
if err != nil { |
||||||
|
return nil, "", err |
||||||
|
} |
||||||
|
comment = string(bytes.TrimSpace(in[i:])) |
||||||
|
return out, comment, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ParseAuthorizedKeys parses a public key from an authorized_keys
|
||||||
|
// file used in OpenSSH according to the sshd(8) manual page.
|
||||||
|
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { |
||||||
|
for len(in) > 0 { |
||||||
|
end := bytes.IndexByte(in, '\n') |
||||||
|
if end != -1 { |
||||||
|
rest = in[end+1:] |
||||||
|
in = in[:end] |
||||||
|
} else { |
||||||
|
rest = nil |
||||||
|
} |
||||||
|
|
||||||
|
end = bytes.IndexByte(in, '\r') |
||||||
|
if end != -1 { |
||||||
|
in = in[:end] |
||||||
|
} |
||||||
|
|
||||||
|
in = bytes.TrimSpace(in) |
||||||
|
if len(in) == 0 || in[0] == '#' { |
||||||
|
in = rest |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
i := bytes.IndexAny(in, " \t") |
||||||
|
if i == -1 { |
||||||
|
in = rest |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||||
|
return out, comment, options, rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
// No key type recognised. Maybe there's an options field at
|
||||||
|
// the beginning.
|
||||||
|
var b byte |
||||||
|
inQuote := false |
||||||
|
var candidateOptions []string |
||||||
|
optionStart := 0 |
||||||
|
for i, b = range in { |
||||||
|
isEnd := !inQuote && (b == ' ' || b == '\t') |
||||||
|
if (b == ',' && !inQuote) || isEnd { |
||||||
|
if i-optionStart > 0 { |
||||||
|
candidateOptions = append(candidateOptions, string(in[optionStart:i])) |
||||||
|
} |
||||||
|
optionStart = i + 1 |
||||||
|
} |
||||||
|
if isEnd { |
||||||
|
break |
||||||
|
} |
||||||
|
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { |
||||||
|
inQuote = !inQuote |
||||||
|
} |
||||||
|
} |
||||||
|
for i < len(in) && (in[i] == ' ' || in[i] == '\t') { |
||||||
|
i++ |
||||||
|
} |
||||||
|
if i == len(in) { |
||||||
|
// Invalid line: unmatched quote
|
||||||
|
in = rest |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
in = in[i:] |
||||||
|
i = bytes.IndexAny(in, " \t") |
||||||
|
if i == -1 { |
||||||
|
in = rest |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||||
|
options = candidateOptions |
||||||
|
return out, comment, options, rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
in = rest |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
return nil, "", nil, nil, errors.New("ssh: no key found") |
||||||
|
} |
||||||
|
|
||||||
|
// ParsePublicKey parses an SSH public key formatted for use in
|
||||||
|
// the SSH wire protocol according to RFC 4253, section 6.6.
|
||||||
|
func ParsePublicKey(in []byte) (out PublicKey, err error) { |
||||||
|
algo, in, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return nil, errShortRead |
||||||
|
} |
||||||
|
var rest []byte |
||||||
|
out, rest, err = parsePubKey(in, string(algo)) |
||||||
|
if len(rest) > 0 { |
||||||
|
return nil, errors.New("ssh: trailing junk in public key") |
||||||
|
} |
||||||
|
|
||||||
|
return out, err |
||||||
|
} |
||||||
|
|
||||||
|
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
|
||||||
|
// authorized_keys file. The return value ends with newline.
|
||||||
|
func MarshalAuthorizedKey(key PublicKey) []byte { |
||||||
|
b := &bytes.Buffer{} |
||||||
|
b.WriteString(key.Type()) |
||||||
|
b.WriteByte(' ') |
||||||
|
e := base64.NewEncoder(base64.StdEncoding, b) |
||||||
|
e.Write(key.Marshal()) |
||||||
|
e.Close() |
||||||
|
b.WriteByte('\n') |
||||||
|
return b.Bytes() |
||||||
|
} |
||||||
|
|
||||||
|
// PublicKey is an abstraction of different types of public keys.
|
||||||
|
type PublicKey interface { |
||||||
|
// Type returns the key's type, e.g. "ssh-rsa".
|
||||||
|
Type() string |
||||||
|
|
||||||
|
// Marshal returns the serialized key data in SSH wire format,
|
||||||
|
// with the name prefix.
|
||||||
|
Marshal() []byte |
||||||
|
|
||||||
|
// Verify that sig is a signature on the given data using this
|
||||||
|
// key. This function will hash the data appropriately first.
|
||||||
|
Verify(data []byte, sig *Signature) error |
||||||
|
} |
||||||
|
|
||||||
|
// A Signer can create signatures that verify against a public key.
|
||||||
|
type Signer interface { |
||||||
|
// PublicKey returns an associated PublicKey instance.
|
||||||
|
PublicKey() PublicKey |
||||||
|
|
||||||
|
// Sign returns raw signature for the given data. This method
|
||||||
|
// will apply the hash specified for the keytype to the data.
|
||||||
|
Sign(rand io.Reader, data []byte) (*Signature, error) |
||||||
|
} |
||||||
|
|
||||||
|
type rsaPublicKey rsa.PublicKey |
||||||
|
|
||||||
|
func (r *rsaPublicKey) Type() string { |
||||||
|
return "ssh-rsa" |
||||||
|
} |
||||||
|
|
||||||
|
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
|
||||||
|
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||||
|
var w struct { |
||||||
|
E *big.Int |
||||||
|
N *big.Int |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
if err := Unmarshal(in, &w); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if w.E.BitLen() > 24 { |
||||||
|
return nil, nil, errors.New("ssh: exponent too large") |
||||||
|
} |
||||||
|
e := w.E.Int64() |
||||||
|
if e < 3 || e&1 == 0 { |
||||||
|
return nil, nil, errors.New("ssh: incorrect exponent") |
||||||
|
} |
||||||
|
|
||||||
|
var key rsa.PublicKey |
||||||
|
key.E = int(e) |
||||||
|
key.N = w.N |
||||||
|
return (*rsaPublicKey)(&key), w.Rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (r *rsaPublicKey) Marshal() []byte { |
||||||
|
e := new(big.Int).SetInt64(int64(r.E)) |
||||||
|
wirekey := struct { |
||||||
|
Name string |
||||||
|
E *big.Int |
||||||
|
N *big.Int |
||||||
|
}{ |
||||||
|
KeyAlgoRSA, |
||||||
|
e, |
||||||
|
r.N, |
||||||
|
} |
||||||
|
return Marshal(&wirekey) |
||||||
|
} |
||||||
|
|
||||||
|
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||||
|
if sig.Format != r.Type() { |
||||||
|
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) |
||||||
|
} |
||||||
|
h := crypto.SHA1.New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) |
||||||
|
} |
||||||
|
|
||||||
|
type rsaPrivateKey struct { |
||||||
|
*rsa.PrivateKey |
||||||
|
} |
||||||
|
|
||||||
|
func (r *rsaPrivateKey) PublicKey() PublicKey { |
||||||
|
return (*rsaPublicKey)(&r.PrivateKey.PublicKey) |
||||||
|
} |
||||||
|
|
||||||
|
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||||
|
h := crypto.SHA1.New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &Signature{ |
||||||
|
Format: r.PublicKey().Type(), |
||||||
|
Blob: blob, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
type dsaPublicKey dsa.PublicKey |
||||||
|
|
||||||
|
func (r *dsaPublicKey) Type() string { |
||||||
|
return "ssh-dss" |
||||||
|
} |
||||||
|
|
||||||
|
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
|
||||||
|
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||||
|
var w struct { |
||||||
|
P, Q, G, Y *big.Int |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
if err := Unmarshal(in, &w); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
key := &dsaPublicKey{ |
||||||
|
Parameters: dsa.Parameters{ |
||||||
|
P: w.P, |
||||||
|
Q: w.Q, |
||||||
|
G: w.G, |
||||||
|
}, |
||||||
|
Y: w.Y, |
||||||
|
} |
||||||
|
return key, w.Rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (k *dsaPublicKey) Marshal() []byte { |
||||||
|
w := struct { |
||||||
|
Name string |
||||||
|
P, Q, G, Y *big.Int |
||||||
|
}{ |
||||||
|
k.Type(), |
||||||
|
k.P, |
||||||
|
k.Q, |
||||||
|
k.G, |
||||||
|
k.Y, |
||||||
|
} |
||||||
|
|
||||||
|
return Marshal(&w) |
||||||
|
} |
||||||
|
|
||||||
|
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||||
|
if sig.Format != k.Type() { |
||||||
|
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) |
||||||
|
} |
||||||
|
h := crypto.SHA1.New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
|
||||||
|
// Per RFC 4253, section 6.6,
|
||||||
|
// The value for 'dss_signature_blob' is encoded as a string containing
|
||||||
|
// r, followed by s (which are 160-bit integers, without lengths or
|
||||||
|
// padding, unsigned, and in network byte order).
|
||||||
|
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
|
||||||
|
if len(sig.Blob) != 40 { |
||||||
|
return errors.New("ssh: DSA signature parse error") |
||||||
|
} |
||||||
|
r := new(big.Int).SetBytes(sig.Blob[:20]) |
||||||
|
s := new(big.Int).SetBytes(sig.Blob[20:]) |
||||||
|
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return errors.New("ssh: signature did not verify") |
||||||
|
} |
||||||
|
|
||||||
|
type dsaPrivateKey struct { |
||||||
|
*dsa.PrivateKey |
||||||
|
} |
||||||
|
|
||||||
|
func (k *dsaPrivateKey) PublicKey() PublicKey { |
||||||
|
return (*dsaPublicKey)(&k.PrivateKey.PublicKey) |
||||||
|
} |
||||||
|
|
||||||
|
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||||
|
h := crypto.SHA1.New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
r, s, err := dsa.Sign(rand, k.PrivateKey, digest) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
sig := make([]byte, 40) |
||||||
|
rb := r.Bytes() |
||||||
|
sb := s.Bytes() |
||||||
|
|
||||||
|
copy(sig[20-len(rb):20], rb) |
||||||
|
copy(sig[40-len(sb):], sb) |
||||||
|
|
||||||
|
return &Signature{ |
||||||
|
Format: k.PublicKey().Type(), |
||||||
|
Blob: sig, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
type ecdsaPublicKey ecdsa.PublicKey |
||||||
|
|
||||||
|
func (key *ecdsaPublicKey) Type() string { |
||||||
|
return "ecdsa-sha2-" + key.nistID() |
||||||
|
} |
||||||
|
|
||||||
|
func (key *ecdsaPublicKey) nistID() string { |
||||||
|
switch key.Params().BitSize { |
||||||
|
case 256: |
||||||
|
return "nistp256" |
||||||
|
case 384: |
||||||
|
return "nistp384" |
||||||
|
case 521: |
||||||
|
return "nistp521" |
||||||
|
} |
||||||
|
panic("ssh: unsupported ecdsa key size") |
||||||
|
} |
||||||
|
|
||||||
|
func supportedEllipticCurve(curve elliptic.Curve) bool { |
||||||
|
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() |
||||||
|
} |
||||||
|
|
||||||
|
// ecHash returns the hash to match the given elliptic curve, see RFC
|
||||||
|
// 5656, section 6.2.1
|
||||||
|
func ecHash(curve elliptic.Curve) crypto.Hash { |
||||||
|
bitSize := curve.Params().BitSize |
||||||
|
switch { |
||||||
|
case bitSize <= 256: |
||||||
|
return crypto.SHA256 |
||||||
|
case bitSize <= 384: |
||||||
|
return crypto.SHA384 |
||||||
|
} |
||||||
|
return crypto.SHA512 |
||||||
|
} |
||||||
|
|
||||||
|
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
|
||||||
|
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||||
|
var w struct { |
||||||
|
Curve string |
||||||
|
KeyBytes []byte |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
if err := Unmarshal(in, &w); err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
key := new(ecdsa.PublicKey) |
||||||
|
|
||||||
|
switch w.Curve { |
||||||
|
case "nistp256": |
||||||
|
key.Curve = elliptic.P256() |
||||||
|
case "nistp384": |
||||||
|
key.Curve = elliptic.P384() |
||||||
|
case "nistp521": |
||||||
|
key.Curve = elliptic.P521() |
||||||
|
default: |
||||||
|
return nil, nil, errors.New("ssh: unsupported curve") |
||||||
|
} |
||||||
|
|
||||||
|
key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) |
||||||
|
if key.X == nil || key.Y == nil { |
||||||
|
return nil, nil, errors.New("ssh: invalid curve point") |
||||||
|
} |
||||||
|
return (*ecdsaPublicKey)(key), w.Rest, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (key *ecdsaPublicKey) Marshal() []byte { |
||||||
|
// See RFC 5656, section 3.1.
|
||||||
|
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) |
||||||
|
w := struct { |
||||||
|
Name string |
||||||
|
ID string |
||||||
|
Key []byte |
||||||
|
}{ |
||||||
|
key.Type(), |
||||||
|
key.nistID(), |
||||||
|
keyBytes, |
||||||
|
} |
||||||
|
|
||||||
|
return Marshal(&w) |
||||||
|
} |
||||||
|
|
||||||
|
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||||
|
if sig.Format != key.Type() { |
||||||
|
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) |
||||||
|
} |
||||||
|
|
||||||
|
h := ecHash(key.Curve).New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
|
||||||
|
// Per RFC 5656, section 3.1.2,
|
||||||
|
// The ecdsa_signature_blob value has the following specific encoding:
|
||||||
|
// mpint r
|
||||||
|
// mpint s
|
||||||
|
var ecSig struct { |
||||||
|
R *big.Int |
||||||
|
S *big.Int |
||||||
|
} |
||||||
|
|
||||||
|
if err := Unmarshal(sig.Blob, &ecSig); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return errors.New("ssh: signature did not verify") |
||||||
|
} |
||||||
|
|
||||||
|
type ecdsaPrivateKey struct { |
||||||
|
*ecdsa.PrivateKey |
||||||
|
} |
||||||
|
|
||||||
|
func (k *ecdsaPrivateKey) PublicKey() PublicKey { |
||||||
|
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) |
||||||
|
} |
||||||
|
|
||||||
|
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||||
|
h := ecHash(k.PrivateKey.PublicKey.Curve).New() |
||||||
|
h.Write(data) |
||||||
|
digest := h.Sum(nil) |
||||||
|
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
sig := make([]byte, intLength(r)+intLength(s)) |
||||||
|
rest := marshalInt(sig, r) |
||||||
|
marshalInt(rest, s) |
||||||
|
return &Signature{ |
||||||
|
Format: k.PublicKey().Type(), |
||||||
|
Blob: sig, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
|
||||||
|
// returns a corresponding Signer instance. EC keys should use P256,
|
||||||
|
// P384 or P521.
|
||||||
|
func NewSignerFromKey(k interface{}) (Signer, error) { |
||||||
|
var sshKey Signer |
||||||
|
switch t := k.(type) { |
||||||
|
case *rsa.PrivateKey: |
||||||
|
sshKey = &rsaPrivateKey{t} |
||||||
|
case *dsa.PrivateKey: |
||||||
|
sshKey = &dsaPrivateKey{t} |
||||||
|
case *ecdsa.PrivateKey: |
||||||
|
if !supportedEllipticCurve(t.Curve) { |
||||||
|
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||||
|
} |
||||||
|
|
||||||
|
sshKey = &ecdsaPrivateKey{t} |
||||||
|
default: |
||||||
|
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||||
|
} |
||||||
|
return sshKey, nil |
||||||
|
} |
||||||
|
|
||||||
|
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
|
||||||
|
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
|
||||||
|
func NewPublicKey(k interface{}) (PublicKey, error) { |
||||||
|
var sshKey PublicKey |
||||||
|
switch t := k.(type) { |
||||||
|
case *rsa.PublicKey: |
||||||
|
sshKey = (*rsaPublicKey)(t) |
||||||
|
case *ecdsa.PublicKey: |
||||||
|
if !supportedEllipticCurve(t.Curve) { |
||||||
|
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||||
|
} |
||||||
|
sshKey = (*ecdsaPublicKey)(t) |
||||||
|
case *dsa.PublicKey: |
||||||
|
sshKey = (*dsaPublicKey)(t) |
||||||
|
default: |
||||||
|
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||||
|
} |
||||||
|
return sshKey, nil |
||||||
|
} |
||||||
|
|
||||||
|
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
|
||||||
|
// the same keys as ParseRawPrivateKey.
|
||||||
|
func ParsePrivateKey(pemBytes []byte) (Signer, error) { |
||||||
|
key, err := ParseRawPrivateKey(pemBytes) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return NewSignerFromKey(key) |
||||||
|
} |
||||||
|
|
||||||
|
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
|
||||||
|
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
|
||||||
|
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { |
||||||
|
block, _ := pem.Decode(pemBytes) |
||||||
|
if block == nil { |
||||||
|
return nil, errors.New("ssh: no key found") |
||||||
|
} |
||||||
|
|
||||||
|
switch block.Type { |
||||||
|
case "RSA PRIVATE KEY": |
||||||
|
return x509.ParsePKCS1PrivateKey(block.Bytes) |
||||||
|
case "EC PRIVATE KEY": |
||||||
|
return x509.ParseECPrivateKey(block.Bytes) |
||||||
|
case "DSA PRIVATE KEY": |
||||||
|
return ParseDSAPrivateKey(block.Bytes) |
||||||
|
default: |
||||||
|
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
|
||||||
|
// specified by the OpenSSL DSA man page.
|
||||||
|
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { |
||||||
|
var k struct { |
||||||
|
Version int |
||||||
|
P *big.Int |
||||||
|
Q *big.Int |
||||||
|
G *big.Int |
||||||
|
Priv *big.Int |
||||||
|
Pub *big.Int |
||||||
|
} |
||||||
|
rest, err := asn1.Unmarshal(der, &k) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) |
||||||
|
} |
||||||
|
if len(rest) > 0 { |
||||||
|
return nil, errors.New("ssh: garbage after DSA key") |
||||||
|
} |
||||||
|
|
||||||
|
return &dsa.PrivateKey{ |
||||||
|
PublicKey: dsa.PublicKey{ |
||||||
|
Parameters: dsa.Parameters{ |
||||||
|
P: k.P, |
||||||
|
Q: k.Q, |
||||||
|
G: k.G, |
||||||
|
}, |
||||||
|
Y: k.Priv, |
||||||
|
}, |
||||||
|
X: k.Pub, |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,306 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/dsa" |
||||||
|
"crypto/ecdsa" |
||||||
|
"crypto/elliptic" |
||||||
|
"crypto/rand" |
||||||
|
"crypto/rsa" |
||||||
|
"encoding/base64" |
||||||
|
"fmt" |
||||||
|
"reflect" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/ssh/testdata" |
||||||
|
) |
||||||
|
|
||||||
|
func rawKey(pub PublicKey) interface{} { |
||||||
|
switch k := pub.(type) { |
||||||
|
case *rsaPublicKey: |
||||||
|
return (*rsa.PublicKey)(k) |
||||||
|
case *dsaPublicKey: |
||||||
|
return (*dsa.PublicKey)(k) |
||||||
|
case *ecdsaPublicKey: |
||||||
|
return (*ecdsa.PublicKey)(k) |
||||||
|
case *Certificate: |
||||||
|
return k |
||||||
|
} |
||||||
|
panic("unknown key type") |
||||||
|
} |
||||||
|
|
||||||
|
func TestKeyMarshalParse(t *testing.T) { |
||||||
|
for _, priv := range testSigners { |
||||||
|
pub := priv.PublicKey() |
||||||
|
roundtrip, err := ParsePublicKey(pub.Marshal()) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("ParsePublicKey(%T): %v", pub, err) |
||||||
|
} |
||||||
|
|
||||||
|
k1 := rawKey(pub) |
||||||
|
k2 := rawKey(roundtrip) |
||||||
|
|
||||||
|
if !reflect.DeepEqual(k1, k2) { |
||||||
|
t.Errorf("got %#v in roundtrip, want %#v", k2, k1) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestUnsupportedCurves(t *testing.T) { |
||||||
|
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("GenerateKey: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||||
|
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||||
|
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestNewPublicKey(t *testing.T) { |
||||||
|
for _, k := range testSigners { |
||||||
|
raw := rawKey(k.PublicKey()) |
||||||
|
// Skip certificates, as NewPublicKey does not support them.
|
||||||
|
if _, ok := raw.(*Certificate); ok { |
||||||
|
continue |
||||||
|
} |
||||||
|
pub, err := NewPublicKey(raw) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("NewPublicKey(%#v): %v", raw, err) |
||||||
|
} |
||||||
|
if !reflect.DeepEqual(k.PublicKey(), pub) { |
||||||
|
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestKeySignVerify(t *testing.T) { |
||||||
|
for _, priv := range testSigners { |
||||||
|
pub := priv.PublicKey() |
||||||
|
|
||||||
|
data := []byte("sign me") |
||||||
|
sig, err := priv.Sign(rand.Reader, data) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Sign(%T): %v", priv, err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := pub.Verify(data, sig); err != nil { |
||||||
|
t.Errorf("publicKey.Verify(%T): %v", priv, err) |
||||||
|
} |
||||||
|
sig.Blob[5]++ |
||||||
|
if err := pub.Verify(data, sig); err == nil { |
||||||
|
t.Errorf("publicKey.Verify on broken sig did not fail") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestParseRSAPrivateKey(t *testing.T) { |
||||||
|
key := testPrivateKeys["rsa"] |
||||||
|
|
||||||
|
rsa, ok := key.(*rsa.PrivateKey) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("got %T, want *rsa.PrivateKey", rsa) |
||||||
|
} |
||||||
|
|
||||||
|
if err := rsa.Validate(); err != nil { |
||||||
|
t.Errorf("Validate: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestParseECPrivateKey(t *testing.T) { |
||||||
|
key := testPrivateKeys["ecdsa"] |
||||||
|
|
||||||
|
ecKey, ok := key.(*ecdsa.PrivateKey) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) |
||||||
|
} |
||||||
|
|
||||||
|
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { |
||||||
|
t.Fatalf("public key does not validate.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestParseDSA(t *testing.T) { |
||||||
|
// We actually exercise the ParsePrivateKey codepath here, as opposed to
|
||||||
|
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
|
||||||
|
// uses.
|
||||||
|
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("ParsePrivateKey returned error: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
data := []byte("sign me") |
||||||
|
sig, err := s.Sign(rand.Reader, data) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("dsa.Sign: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := s.PublicKey().Verify(data, sig); err != nil { |
||||||
|
t.Errorf("Verify failed: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Tests for authorized_keys parsing.
|
||||||
|
|
||||||
|
// getTestKey returns a public key, and its base64 encoding.
|
||||||
|
func getTestKey() (PublicKey, string) { |
||||||
|
k := testPublicKeys["rsa"] |
||||||
|
|
||||||
|
b := &bytes.Buffer{} |
||||||
|
e := base64.NewEncoder(base64.StdEncoding, b) |
||||||
|
e.Write(k.Marshal()) |
||||||
|
e.Close() |
||||||
|
|
||||||
|
return k, b.String() |
||||||
|
} |
||||||
|
|
||||||
|
func TestMarshalParsePublicKey(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) |
||||||
|
|
||||||
|
authKeys := MarshalAuthorizedKey(pub) |
||||||
|
actualFields := strings.Fields(string(authKeys)) |
||||||
|
if len(actualFields) == 0 { |
||||||
|
t.Fatalf("failed authKeys: %v", authKeys) |
||||||
|
} |
||||||
|
|
||||||
|
// drop the comment
|
||||||
|
expectedFields := strings.Fields(line)[0:2] |
||||||
|
|
||||||
|
if !reflect.DeepEqual(actualFields, expectedFields) { |
||||||
|
t.Errorf("got %v, expected %v", actualFields, expectedFields) |
||||||
|
} |
||||||
|
|
||||||
|
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("cannot parse %v: %v", line, err) |
||||||
|
} |
||||||
|
if !reflect.DeepEqual(actPub, pub) { |
||||||
|
t.Errorf("got %v, expected %v", actPub, pub) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type authResult struct { |
||||||
|
pubKey PublicKey |
||||||
|
options []string |
||||||
|
comments string |
||||||
|
rest string |
||||||
|
ok bool |
||||||
|
} |
||||||
|
|
||||||
|
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { |
||||||
|
rest := authKeys |
||||||
|
var values []authResult |
||||||
|
for len(rest) > 0 { |
||||||
|
var r authResult |
||||||
|
var err error |
||||||
|
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) |
||||||
|
r.ok = (err == nil) |
||||||
|
t.Log(err) |
||||||
|
r.rest = string(rest) |
||||||
|
values = append(values, r) |
||||||
|
} |
||||||
|
|
||||||
|
if !reflect.DeepEqual(values, expected) { |
||||||
|
t.Errorf("got %#v, expected %#v", values, expected) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthorizedKeyBasic(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
line := "ssh-rsa " + pubSerialized + " user@host" |
||||||
|
testAuthorizedKeys(t, []byte(line), |
||||||
|
[]authResult{ |
||||||
|
{pub, nil, "user@host", "", true}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuth(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
authWithOptions := []string{ |
||||||
|
`# comments to ignore before any keys...`, |
||||||
|
``, |
||||||
|
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, |
||||||
|
`# comments to ignore, along with a blank line`, |
||||||
|
``, |
||||||
|
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, |
||||||
|
``, |
||||||
|
`# more comments, plus a invalid entry`, |
||||||
|
`ssh-rsa data-that-will-not-parse user@host3`, |
||||||
|
} |
||||||
|
for _, eol := range []string{"\n", "\r\n"} { |
||||||
|
authOptions := strings.Join(authWithOptions, eol) |
||||||
|
rest2 := strings.Join(authWithOptions[3:], eol) |
||||||
|
rest3 := strings.Join(authWithOptions[6:], eol) |
||||||
|
testAuthorizedKeys(t, []byte(authOptions), []authResult{ |
||||||
|
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, |
||||||
|
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, |
||||||
|
{nil, nil, "", "", false}, |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthWithQuotedSpaceInEnv(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||||
|
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ |
||||||
|
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthWithQuotedCommaInEnv(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||||
|
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ |
||||||
|
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthWithQuotedQuoteInEnv(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) |
||||||
|
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) |
||||||
|
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ |
||||||
|
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||||
|
}) |
||||||
|
|
||||||
|
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ |
||||||
|
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthWithInvalidSpace(t *testing.T) { |
||||||
|
_, pubSerialized := getTestKey() |
||||||
|
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||||
|
#more to follow but still no valid keys`) |
||||||
|
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ |
||||||
|
{nil, nil, "", "", false}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthWithMissingQuote(t *testing.T) { |
||||||
|
pub, pubSerialized := getTestKey() |
||||||
|
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||||
|
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) |
||||||
|
|
||||||
|
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ |
||||||
|
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TestInvalidEntry(t *testing.T) { |
||||||
|
authInvalid := []byte(`ssh-rsa`) |
||||||
|
_, _, _, _, err := ParseAuthorizedKey(authInvalid) |
||||||
|
if err == nil { |
||||||
|
t.Errorf("got valid entry for %q", authInvalid) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,57 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
// Message authentication support
|
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/hmac" |
||||||
|
"crypto/sha1" |
||||||
|
"crypto/sha256" |
||||||
|
"hash" |
||||||
|
) |
||||||
|
|
||||||
|
type macMode struct { |
||||||
|
keySize int |
||||||
|
new func(key []byte) hash.Hash |
||||||
|
} |
||||||
|
|
||||||
|
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||||
|
// a given size.
|
||||||
|
type truncatingMAC struct { |
||||||
|
length int |
||||||
|
hmac hash.Hash |
||||||
|
} |
||||||
|
|
||||||
|
func (t truncatingMAC) Write(data []byte) (int, error) { |
||||||
|
return t.hmac.Write(data) |
||||||
|
} |
||||||
|
|
||||||
|
func (t truncatingMAC) Sum(in []byte) []byte { |
||||||
|
out := t.hmac.Sum(in) |
||||||
|
return out[:len(in)+t.length] |
||||||
|
} |
||||||
|
|
||||||
|
func (t truncatingMAC) Reset() { |
||||||
|
t.hmac.Reset() |
||||||
|
} |
||||||
|
|
||||||
|
func (t truncatingMAC) Size() int { |
||||||
|
return t.length |
||||||
|
} |
||||||
|
|
||||||
|
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } |
||||||
|
|
||||||
|
var macModes = map[string]*macMode{ |
||||||
|
"hmac-sha2-256": {32, func(key []byte) hash.Hash { |
||||||
|
return hmac.New(sha256.New, key) |
||||||
|
}}, |
||||||
|
"hmac-sha1": {20, func(key []byte) hash.Hash { |
||||||
|
return hmac.New(sha1.New, key) |
||||||
|
}}, |
||||||
|
"hmac-sha1-96": {20, func(key []byte) hash.Hash { |
||||||
|
return truncatingMAC{12, hmac.New(sha1.New, key)} |
||||||
|
}}, |
||||||
|
} |
@ -0,0 +1,110 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"sync" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
// An in-memory packetConn. It is safe to call Close and writePacket
|
||||||
|
// from different goroutines.
|
||||||
|
type memTransport struct { |
||||||
|
eof bool |
||||||
|
pending [][]byte |
||||||
|
write *memTransport |
||||||
|
sync.Mutex |
||||||
|
*sync.Cond |
||||||
|
} |
||||||
|
|
||||||
|
func (t *memTransport) readPacket() ([]byte, error) { |
||||||
|
t.Lock() |
||||||
|
defer t.Unlock() |
||||||
|
for { |
||||||
|
if len(t.pending) > 0 { |
||||||
|
r := t.pending[0] |
||||||
|
t.pending = t.pending[1:] |
||||||
|
return r, nil |
||||||
|
} |
||||||
|
if t.eof { |
||||||
|
return nil, io.EOF |
||||||
|
} |
||||||
|
t.Cond.Wait() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (t *memTransport) closeSelf() error { |
||||||
|
t.Lock() |
||||||
|
defer t.Unlock() |
||||||
|
if t.eof { |
||||||
|
return io.EOF |
||||||
|
} |
||||||
|
t.eof = true |
||||||
|
t.Cond.Broadcast() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (t *memTransport) Close() error { |
||||||
|
err := t.write.closeSelf() |
||||||
|
t.closeSelf() |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *memTransport) writePacket(p []byte) error { |
||||||
|
t.write.Lock() |
||||||
|
defer t.write.Unlock() |
||||||
|
if t.write.eof { |
||||||
|
return io.EOF |
||||||
|
} |
||||||
|
c := make([]byte, len(p)) |
||||||
|
copy(c, p) |
||||||
|
t.write.pending = append(t.write.pending, c) |
||||||
|
t.write.Cond.Signal() |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func memPipe() (a, b packetConn) { |
||||||
|
t1 := memTransport{} |
||||||
|
t2 := memTransport{} |
||||||
|
t1.write = &t2 |
||||||
|
t2.write = &t1 |
||||||
|
t1.Cond = sync.NewCond(&t1.Mutex) |
||||||
|
t2.Cond = sync.NewCond(&t2.Mutex) |
||||||
|
return &t1, &t2 |
||||||
|
} |
||||||
|
|
||||||
|
func TestMemPipe(t *testing.T) { |
||||||
|
a, b := memPipe() |
||||||
|
if err := a.writePacket([]byte{42}); err != nil { |
||||||
|
t.Fatalf("writePacket: %v", err) |
||||||
|
} |
||||||
|
if err := a.Close(); err != nil { |
||||||
|
t.Fatal("Close: ", err) |
||||||
|
} |
||||||
|
p, err := b.readPacket() |
||||||
|
if err != nil { |
||||||
|
t.Fatal("readPacket: ", err) |
||||||
|
} |
||||||
|
if len(p) != 1 || p[0] != 42 { |
||||||
|
t.Fatalf("got %v, want {42}", p) |
||||||
|
} |
||||||
|
p, err = b.readPacket() |
||||||
|
if err != io.EOF { |
||||||
|
t.Fatalf("got %v, %v, want EOF", p, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestDoubleClose(t *testing.T) { |
||||||
|
a, _ := memPipe() |
||||||
|
err := a.Close() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("Close: %v", err) |
||||||
|
} |
||||||
|
err = a.Close() |
||||||
|
if err != io.EOF { |
||||||
|
t.Errorf("expect EOF on double close.") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,725 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"math/big" |
||||||
|
"reflect" |
||||||
|
"strconv" |
||||||
|
) |
||||||
|
|
||||||
|
// These are SSH message type numbers. They are scattered around several
|
||||||
|
// documents but many were taken from [SSH-PARAMETERS].
|
||||||
|
const ( |
||||||
|
msgIgnore = 2 |
||||||
|
msgUnimplemented = 3 |
||||||
|
msgDebug = 4 |
||||||
|
msgNewKeys = 21 |
||||||
|
|
||||||
|
// Standard authentication messages
|
||||||
|
msgUserAuthSuccess = 52 |
||||||
|
msgUserAuthBanner = 53 |
||||||
|
) |
||||||
|
|
||||||
|
// SSH messages:
|
||||||
|
//
|
||||||
|
// These structures mirror the wire format of the corresponding SSH messages.
|
||||||
|
// They are marshaled using reflection with the marshal and unmarshal functions
|
||||||
|
// in this file. The only wrinkle is that a final member of type []byte with a
|
||||||
|
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
||||||
|
|
||||||
|
// See RFC 4253, section 11.1.
|
||||||
|
const msgDisconnect = 1 |
||||||
|
|
||||||
|
// disconnectMsg is the message that signals a disconnect. It is also
|
||||||
|
// the error type returned from mux.Wait()
|
||||||
|
type disconnectMsg struct { |
||||||
|
Reason uint32 `sshtype:"1"` |
||||||
|
Message string |
||||||
|
Language string |
||||||
|
} |
||||||
|
|
||||||
|
func (d *disconnectMsg) Error() string { |
||||||
|
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4253, section 7.1.
|
||||||
|
const msgKexInit = 20 |
||||||
|
|
||||||
|
type kexInitMsg struct { |
||||||
|
Cookie [16]byte `sshtype:"20"` |
||||||
|
KexAlgos []string |
||||||
|
ServerHostKeyAlgos []string |
||||||
|
CiphersClientServer []string |
||||||
|
CiphersServerClient []string |
||||||
|
MACsClientServer []string |
||||||
|
MACsServerClient []string |
||||||
|
CompressionClientServer []string |
||||||
|
CompressionServerClient []string |
||||||
|
LanguagesClientServer []string |
||||||
|
LanguagesServerClient []string |
||||||
|
FirstKexFollows bool |
||||||
|
Reserved uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4253, section 8.
|
||||||
|
|
||||||
|
// Diffie-Helman
|
||||||
|
const msgKexDHInit = 30 |
||||||
|
|
||||||
|
type kexDHInitMsg struct { |
||||||
|
X *big.Int `sshtype:"30"` |
||||||
|
} |
||||||
|
|
||||||
|
const msgKexECDHInit = 30 |
||||||
|
|
||||||
|
type kexECDHInitMsg struct { |
||||||
|
ClientPubKey []byte `sshtype:"30"` |
||||||
|
} |
||||||
|
|
||||||
|
const msgKexECDHReply = 31 |
||||||
|
|
||||||
|
type kexECDHReplyMsg struct { |
||||||
|
HostKey []byte `sshtype:"31"` |
||||||
|
EphemeralPubKey []byte |
||||||
|
Signature []byte |
||||||
|
} |
||||||
|
|
||||||
|
const msgKexDHReply = 31 |
||||||
|
|
||||||
|
type kexDHReplyMsg struct { |
||||||
|
HostKey []byte `sshtype:"31"` |
||||||
|
Y *big.Int |
||||||
|
Signature []byte |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4253, section 10.
|
||||||
|
const msgServiceRequest = 5 |
||||||
|
|
||||||
|
type serviceRequestMsg struct { |
||||||
|
Service string `sshtype:"5"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4253, section 10.
|
||||||
|
const msgServiceAccept = 6 |
||||||
|
|
||||||
|
type serviceAcceptMsg struct { |
||||||
|
Service string `sshtype:"6"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4252, section 5.
|
||||||
|
const msgUserAuthRequest = 50 |
||||||
|
|
||||||
|
type userAuthRequestMsg struct { |
||||||
|
User string `sshtype:"50"` |
||||||
|
Service string |
||||||
|
Method string |
||||||
|
Payload []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4252, section 5.1
|
||||||
|
const msgUserAuthFailure = 51 |
||||||
|
|
||||||
|
type userAuthFailureMsg struct { |
||||||
|
Methods []string `sshtype:"51"` |
||||||
|
PartialSuccess bool |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4256, section 3.2
|
||||||
|
const msgUserAuthInfoRequest = 60 |
||||||
|
const msgUserAuthInfoResponse = 61 |
||||||
|
|
||||||
|
type userAuthInfoRequestMsg struct { |
||||||
|
User string `sshtype:"60"` |
||||||
|
Instruction string |
||||||
|
DeprecatedLanguage string |
||||||
|
NumPrompts uint32 |
||||||
|
Prompts []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.1.
|
||||||
|
const msgChannelOpen = 90 |
||||||
|
|
||||||
|
type channelOpenMsg struct { |
||||||
|
ChanType string `sshtype:"90"` |
||||||
|
PeersId uint32 |
||||||
|
PeersWindow uint32 |
||||||
|
MaxPacketSize uint32 |
||||||
|
TypeSpecificData []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
const msgChannelExtendedData = 95 |
||||||
|
const msgChannelData = 94 |
||||||
|
|
||||||
|
// See RFC 4254, section 5.1.
|
||||||
|
const msgChannelOpenConfirm = 91 |
||||||
|
|
||||||
|
type channelOpenConfirmMsg struct { |
||||||
|
PeersId uint32 `sshtype:"91"` |
||||||
|
MyId uint32 |
||||||
|
MyWindow uint32 |
||||||
|
MaxPacketSize uint32 |
||||||
|
TypeSpecificData []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.1.
|
||||||
|
const msgChannelOpenFailure = 92 |
||||||
|
|
||||||
|
type channelOpenFailureMsg struct { |
||||||
|
PeersId uint32 `sshtype:"92"` |
||||||
|
Reason RejectionReason |
||||||
|
Message string |
||||||
|
Language string |
||||||
|
} |
||||||
|
|
||||||
|
const msgChannelRequest = 98 |
||||||
|
|
||||||
|
type channelRequestMsg struct { |
||||||
|
PeersId uint32 `sshtype:"98"` |
||||||
|
Request string |
||||||
|
WantReply bool |
||||||
|
RequestSpecificData []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.4.
|
||||||
|
const msgChannelSuccess = 99 |
||||||
|
|
||||||
|
type channelRequestSuccessMsg struct { |
||||||
|
PeersId uint32 `sshtype:"99"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.4.
|
||||||
|
const msgChannelFailure = 100 |
||||||
|
|
||||||
|
type channelRequestFailureMsg struct { |
||||||
|
PeersId uint32 `sshtype:"100"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.3
|
||||||
|
const msgChannelClose = 97 |
||||||
|
|
||||||
|
type channelCloseMsg struct { |
||||||
|
PeersId uint32 `sshtype:"97"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.3
|
||||||
|
const msgChannelEOF = 96 |
||||||
|
|
||||||
|
type channelEOFMsg struct { |
||||||
|
PeersId uint32 `sshtype:"96"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 4
|
||||||
|
const msgGlobalRequest = 80 |
||||||
|
|
||||||
|
type globalRequestMsg struct { |
||||||
|
Type string `sshtype:"80"` |
||||||
|
WantReply bool |
||||||
|
Data []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 4
|
||||||
|
const msgRequestSuccess = 81 |
||||||
|
|
||||||
|
type globalRequestSuccessMsg struct { |
||||||
|
Data []byte `ssh:"rest" sshtype:"81"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 4
|
||||||
|
const msgRequestFailure = 82 |
||||||
|
|
||||||
|
type globalRequestFailureMsg struct { |
||||||
|
Data []byte `ssh:"rest" sshtype:"82"` |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 5.2
|
||||||
|
const msgChannelWindowAdjust = 93 |
||||||
|
|
||||||
|
type windowAdjustMsg struct { |
||||||
|
PeersId uint32 `sshtype:"93"` |
||||||
|
AdditionalBytes uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4252, section 7
|
||||||
|
const msgUserAuthPubKeyOk = 60 |
||||||
|
|
||||||
|
type userAuthPubKeyOkMsg struct { |
||||||
|
Algo string `sshtype:"60"` |
||||||
|
PubKey []byte |
||||||
|
} |
||||||
|
|
||||||
|
// typeTag returns the type byte for the given type. The type should
|
||||||
|
// be struct.
|
||||||
|
func typeTag(structType reflect.Type) byte { |
||||||
|
var tag byte |
||||||
|
var tagStr string |
||||||
|
tagStr = structType.Field(0).Tag.Get("sshtype") |
||||||
|
i, err := strconv.Atoi(tagStr) |
||||||
|
if err == nil { |
||||||
|
tag = byte(i) |
||||||
|
} |
||||||
|
return tag |
||||||
|
} |
||||||
|
|
||||||
|
func fieldError(t reflect.Type, field int, problem string) error { |
||||||
|
if problem != "" { |
||||||
|
problem = ": " + problem |
||||||
|
} |
||||||
|
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) |
||||||
|
} |
||||||
|
|
||||||
|
var errShortRead = errors.New("ssh: short read") |
||||||
|
|
||||||
|
// Unmarshal parses data in SSH wire format into a structure. The out
|
||||||
|
// argument should be a pointer to struct. If the first member of the
|
||||||
|
// struct has the "sshtype" tag set to a number in decimal, the packet
|
||||||
|
// must start that number. In case of error, Unmarshal returns a
|
||||||
|
// ParseError or UnexpectedMessageError.
|
||||||
|
func Unmarshal(data []byte, out interface{}) error { |
||||||
|
v := reflect.ValueOf(out).Elem() |
||||||
|
structType := v.Type() |
||||||
|
expectedType := typeTag(structType) |
||||||
|
if len(data) == 0 { |
||||||
|
return parseError(expectedType) |
||||||
|
} |
||||||
|
if expectedType > 0 { |
||||||
|
if data[0] != expectedType { |
||||||
|
return unexpectedMessageError(expectedType, data[0]) |
||||||
|
} |
||||||
|
data = data[1:] |
||||||
|
} |
||||||
|
|
||||||
|
var ok bool |
||||||
|
for i := 0; i < v.NumField(); i++ { |
||||||
|
field := v.Field(i) |
||||||
|
t := field.Type() |
||||||
|
switch t.Kind() { |
||||||
|
case reflect.Bool: |
||||||
|
if len(data) < 1 { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.SetBool(data[0] != 0) |
||||||
|
data = data[1:] |
||||||
|
case reflect.Array: |
||||||
|
if t.Elem().Kind() != reflect.Uint8 { |
||||||
|
return fieldError(structType, i, "array of unsupported type") |
||||||
|
} |
||||||
|
if len(data) < t.Len() { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
for j, n := 0, t.Len(); j < n; j++ { |
||||||
|
field.Index(j).Set(reflect.ValueOf(data[j])) |
||||||
|
} |
||||||
|
data = data[t.Len():] |
||||||
|
case reflect.Uint64: |
||||||
|
var u64 uint64 |
||||||
|
if u64, data, ok = parseUint64(data); !ok { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.SetUint(u64) |
||||||
|
case reflect.Uint32: |
||||||
|
var u32 uint32 |
||||||
|
if u32, data, ok = parseUint32(data); !ok { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.SetUint(uint64(u32)) |
||||||
|
case reflect.Uint8: |
||||||
|
if len(data) < 1 { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.SetUint(uint64(data[0])) |
||||||
|
data = data[1:] |
||||||
|
case reflect.String: |
||||||
|
var s []byte |
||||||
|
if s, data, ok = parseString(data); !ok { |
||||||
|
return fieldError(structType, i, "") |
||||||
|
} |
||||||
|
field.SetString(string(s)) |
||||||
|
case reflect.Slice: |
||||||
|
switch t.Elem().Kind() { |
||||||
|
case reflect.Uint8: |
||||||
|
if structType.Field(i).Tag.Get("ssh") == "rest" { |
||||||
|
field.Set(reflect.ValueOf(data)) |
||||||
|
data = nil |
||||||
|
} else { |
||||||
|
var s []byte |
||||||
|
if s, data, ok = parseString(data); !ok { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.Set(reflect.ValueOf(s)) |
||||||
|
} |
||||||
|
case reflect.String: |
||||||
|
var nl []string |
||||||
|
if nl, data, ok = parseNameList(data); !ok { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.Set(reflect.ValueOf(nl)) |
||||||
|
default: |
||||||
|
return fieldError(structType, i, "slice of unsupported type") |
||||||
|
} |
||||||
|
case reflect.Ptr: |
||||||
|
if t == bigIntType { |
||||||
|
var n *big.Int |
||||||
|
if n, data, ok = parseInt(data); !ok { |
||||||
|
return errShortRead |
||||||
|
} |
||||||
|
field.Set(reflect.ValueOf(n)) |
||||||
|
} else { |
||||||
|
return fieldError(structType, i, "pointer to unsupported type") |
||||||
|
} |
||||||
|
default: |
||||||
|
return fieldError(structType, i, "unsupported type") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if len(data) != 0 { |
||||||
|
return parseError(expectedType) |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Marshal serializes the message in msg to SSH wire format. The msg
|
||||||
|
// argument should be a struct or pointer to struct. If the first
|
||||||
|
// member has the "sshtype" tag set to a number in decimal, that
|
||||||
|
// number is prepended to the result. If the last of member has the
|
||||||
|
// "ssh" tag set to "rest", its contents are appended to the output.
|
||||||
|
func Marshal(msg interface{}) []byte { |
||||||
|
out := make([]byte, 0, 64) |
||||||
|
return marshalStruct(out, msg) |
||||||
|
} |
||||||
|
|
||||||
|
func marshalStruct(out []byte, msg interface{}) []byte { |
||||||
|
v := reflect.Indirect(reflect.ValueOf(msg)) |
||||||
|
msgType := typeTag(v.Type()) |
||||||
|
if msgType > 0 { |
||||||
|
out = append(out, msgType) |
||||||
|
} |
||||||
|
|
||||||
|
for i, n := 0, v.NumField(); i < n; i++ { |
||||||
|
field := v.Field(i) |
||||||
|
switch t := field.Type(); t.Kind() { |
||||||
|
case reflect.Bool: |
||||||
|
var v uint8 |
||||||
|
if field.Bool() { |
||||||
|
v = 1 |
||||||
|
} |
||||||
|
out = append(out, v) |
||||||
|
case reflect.Array: |
||||||
|
if t.Elem().Kind() != reflect.Uint8 { |
||||||
|
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) |
||||||
|
} |
||||||
|
for j, l := 0, t.Len(); j < l; j++ { |
||||||
|
out = append(out, uint8(field.Index(j).Uint())) |
||||||
|
} |
||||||
|
case reflect.Uint32: |
||||||
|
out = appendU32(out, uint32(field.Uint())) |
||||||
|
case reflect.Uint64: |
||||||
|
out = appendU64(out, uint64(field.Uint())) |
||||||
|
case reflect.Uint8: |
||||||
|
out = append(out, uint8(field.Uint())) |
||||||
|
case reflect.String: |
||||||
|
s := field.String() |
||||||
|
out = appendInt(out, len(s)) |
||||||
|
out = append(out, s...) |
||||||
|
case reflect.Slice: |
||||||
|
switch t.Elem().Kind() { |
||||||
|
case reflect.Uint8: |
||||||
|
if v.Type().Field(i).Tag.Get("ssh") != "rest" { |
||||||
|
out = appendInt(out, field.Len()) |
||||||
|
} |
||||||
|
out = append(out, field.Bytes()...) |
||||||
|
case reflect.String: |
||||||
|
offset := len(out) |
||||||
|
out = appendU32(out, 0) |
||||||
|
if n := field.Len(); n > 0 { |
||||||
|
for j := 0; j < n; j++ { |
||||||
|
f := field.Index(j) |
||||||
|
if j != 0 { |
||||||
|
out = append(out, ',') |
||||||
|
} |
||||||
|
out = append(out, f.String()...) |
||||||
|
} |
||||||
|
// overwrite length value
|
||||||
|
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) |
||||||
|
} |
||||||
|
default: |
||||||
|
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) |
||||||
|
} |
||||||
|
case reflect.Ptr: |
||||||
|
if t == bigIntType { |
||||||
|
var n *big.Int |
||||||
|
nValue := reflect.ValueOf(&n) |
||||||
|
nValue.Elem().Set(field) |
||||||
|
needed := intLength(n) |
||||||
|
oldLength := len(out) |
||||||
|
|
||||||
|
if cap(out)-len(out) < needed { |
||||||
|
newOut := make([]byte, len(out), 2*(len(out)+needed)) |
||||||
|
copy(newOut, out) |
||||||
|
out = newOut |
||||||
|
} |
||||||
|
out = out[:oldLength+needed] |
||||||
|
marshalInt(out[oldLength:], n) |
||||||
|
} else { |
||||||
|
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return out |
||||||
|
} |
||||||
|
|
||||||
|
var bigOne = big.NewInt(1) |
||||||
|
|
||||||
|
func parseString(in []byte) (out, rest []byte, ok bool) { |
||||||
|
if len(in) < 4 { |
||||||
|
return |
||||||
|
} |
||||||
|
length := binary.BigEndian.Uint32(in) |
||||||
|
in = in[4:] |
||||||
|
if uint32(len(in)) < length { |
||||||
|
return |
||||||
|
} |
||||||
|
out = in[:length] |
||||||
|
rest = in[length:] |
||||||
|
ok = true |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
comma = []byte{','} |
||||||
|
emptyNameList = []string{} |
||||||
|
) |
||||||
|
|
||||||
|
func parseNameList(in []byte) (out []string, rest []byte, ok bool) { |
||||||
|
contents, rest, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
if len(contents) == 0 { |
||||||
|
out = emptyNameList |
||||||
|
return |
||||||
|
} |
||||||
|
parts := bytes.Split(contents, comma) |
||||||
|
out = make([]string, len(parts)) |
||||||
|
for i, part := range parts { |
||||||
|
out[i] = string(part) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { |
||||||
|
contents, rest, ok := parseString(in) |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
out = new(big.Int) |
||||||
|
|
||||||
|
if len(contents) > 0 && contents[0]&0x80 == 0x80 { |
||||||
|
// This is a negative number
|
||||||
|
notBytes := make([]byte, len(contents)) |
||||||
|
for i := range notBytes { |
||||||
|
notBytes[i] = ^contents[i] |
||||||
|
} |
||||||
|
out.SetBytes(notBytes) |
||||||
|
out.Add(out, bigOne) |
||||||
|
out.Neg(out) |
||||||
|
} else { |
||||||
|
// Positive number
|
||||||
|
out.SetBytes(contents) |
||||||
|
} |
||||||
|
ok = true |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func parseUint32(in []byte) (uint32, []byte, bool) { |
||||||
|
if len(in) < 4 { |
||||||
|
return 0, nil, false |
||||||
|
} |
||||||
|
return binary.BigEndian.Uint32(in), in[4:], true |
||||||
|
} |
||||||
|
|
||||||
|
func parseUint64(in []byte) (uint64, []byte, bool) { |
||||||
|
if len(in) < 8 { |
||||||
|
return 0, nil, false |
||||||
|
} |
||||||
|
return binary.BigEndian.Uint64(in), in[8:], true |
||||||
|
} |
||||||
|
|
||||||
|
func intLength(n *big.Int) int { |
||||||
|
length := 4 /* length bytes */ |
||||||
|
if n.Sign() < 0 { |
||||||
|
nMinus1 := new(big.Int).Neg(n) |
||||||
|
nMinus1.Sub(nMinus1, bigOne) |
||||||
|
bitLen := nMinus1.BitLen() |
||||||
|
if bitLen%8 == 0 { |
||||||
|
// The number will need 0xff padding
|
||||||
|
length++ |
||||||
|
} |
||||||
|
length += (bitLen + 7) / 8 |
||||||
|
} else if n.Sign() == 0 { |
||||||
|
// A zero is the zero length string
|
||||||
|
} else { |
||||||
|
bitLen := n.BitLen() |
||||||
|
if bitLen%8 == 0 { |
||||||
|
// The number will need 0x00 padding
|
||||||
|
length++ |
||||||
|
} |
||||||
|
length += (bitLen + 7) / 8 |
||||||
|
} |
||||||
|
|
||||||
|
return length |
||||||
|
} |
||||||
|
|
||||||
|
func marshalUint32(to []byte, n uint32) []byte { |
||||||
|
binary.BigEndian.PutUint32(to, n) |
||||||
|
return to[4:] |
||||||
|
} |
||||||
|
|
||||||
|
func marshalUint64(to []byte, n uint64) []byte { |
||||||
|
binary.BigEndian.PutUint64(to, n) |
||||||
|
return to[8:] |
||||||
|
} |
||||||
|
|
||||||
|
func marshalInt(to []byte, n *big.Int) []byte { |
||||||
|
lengthBytes := to |
||||||
|
to = to[4:] |
||||||
|
length := 0 |
||||||
|
|
||||||
|
if n.Sign() < 0 { |
||||||
|
// A negative number has to be converted to two's-complement
|
||||||
|
// form. So we'll subtract 1 and invert. If the
|
||||||
|
// most-significant-bit isn't set then we'll need to pad the
|
||||||
|
// beginning with 0xff in order to keep the number negative.
|
||||||
|
nMinus1 := new(big.Int).Neg(n) |
||||||
|
nMinus1.Sub(nMinus1, bigOne) |
||||||
|
bytes := nMinus1.Bytes() |
||||||
|
for i := range bytes { |
||||||
|
bytes[i] ^= 0xff |
||||||
|
} |
||||||
|
if len(bytes) == 0 || bytes[0]&0x80 == 0 { |
||||||
|
to[0] = 0xff |
||||||
|
to = to[1:] |
||||||
|
length++ |
||||||
|
} |
||||||
|
nBytes := copy(to, bytes) |
||||||
|
to = to[nBytes:] |
||||||
|
length += nBytes |
||||||
|
} else if n.Sign() == 0 { |
||||||
|
// A zero is the zero length string
|
||||||
|
} else { |
||||||
|
bytes := n.Bytes() |
||||||
|
if len(bytes) > 0 && bytes[0]&0x80 != 0 { |
||||||
|
// We'll have to pad this with a 0x00 in order to
|
||||||
|
// stop it looking like a negative number.
|
||||||
|
to[0] = 0 |
||||||
|
to = to[1:] |
||||||
|
length++ |
||||||
|
} |
||||||
|
nBytes := copy(to, bytes) |
||||||
|
to = to[nBytes:] |
||||||
|
length += nBytes |
||||||
|
} |
||||||
|
|
||||||
|
lengthBytes[0] = byte(length >> 24) |
||||||
|
lengthBytes[1] = byte(length >> 16) |
||||||
|
lengthBytes[2] = byte(length >> 8) |
||||||
|
lengthBytes[3] = byte(length) |
||||||
|
return to |
||||||
|
} |
||||||
|
|
||||||
|
func writeInt(w io.Writer, n *big.Int) { |
||||||
|
length := intLength(n) |
||||||
|
buf := make([]byte, length) |
||||||
|
marshalInt(buf, n) |
||||||
|
w.Write(buf) |
||||||
|
} |
||||||
|
|
||||||
|
func writeString(w io.Writer, s []byte) { |
||||||
|
var lengthBytes [4]byte |
||||||
|
lengthBytes[0] = byte(len(s) >> 24) |
||||||
|
lengthBytes[1] = byte(len(s) >> 16) |
||||||
|
lengthBytes[2] = byte(len(s) >> 8) |
||||||
|
lengthBytes[3] = byte(len(s)) |
||||||
|
w.Write(lengthBytes[:]) |
||||||
|
w.Write(s) |
||||||
|
} |
||||||
|
|
||||||
|
func stringLength(n int) int { |
||||||
|
return 4 + n |
||||||
|
} |
||||||
|
|
||||||
|
func marshalString(to []byte, s []byte) []byte { |
||||||
|
to[0] = byte(len(s) >> 24) |
||||||
|
to[1] = byte(len(s) >> 16) |
||||||
|
to[2] = byte(len(s) >> 8) |
||||||
|
to[3] = byte(len(s)) |
||||||
|
to = to[4:] |
||||||
|
copy(to, s) |
||||||
|
return to[len(s):] |
||||||
|
} |
||||||
|
|
||||||
|
var bigIntType = reflect.TypeOf((*big.Int)(nil)) |
||||||
|
|
||||||
|
// Decode a packet into its corresponding message.
|
||||||
|
func decode(packet []byte) (interface{}, error) { |
||||||
|
var msg interface{} |
||||||
|
switch packet[0] { |
||||||
|
case msgDisconnect: |
||||||
|
msg = new(disconnectMsg) |
||||||
|
case msgServiceRequest: |
||||||
|
msg = new(serviceRequestMsg) |
||||||
|
case msgServiceAccept: |
||||||
|
msg = new(serviceAcceptMsg) |
||||||
|
case msgKexInit: |
||||||
|
msg = new(kexInitMsg) |
||||||
|
case msgKexDHInit: |
||||||
|
msg = new(kexDHInitMsg) |
||||||
|
case msgKexDHReply: |
||||||
|
msg = new(kexDHReplyMsg) |
||||||
|
case msgUserAuthRequest: |
||||||
|
msg = new(userAuthRequestMsg) |
||||||
|
case msgUserAuthFailure: |
||||||
|
msg = new(userAuthFailureMsg) |
||||||
|
case msgUserAuthPubKeyOk: |
||||||
|
msg = new(userAuthPubKeyOkMsg) |
||||||
|
case msgGlobalRequest: |
||||||
|
msg = new(globalRequestMsg) |
||||||
|
case msgRequestSuccess: |
||||||
|
msg = new(globalRequestSuccessMsg) |
||||||
|
case msgRequestFailure: |
||||||
|
msg = new(globalRequestFailureMsg) |
||||||
|
case msgChannelOpen: |
||||||
|
msg = new(channelOpenMsg) |
||||||
|
case msgChannelOpenConfirm: |
||||||
|
msg = new(channelOpenConfirmMsg) |
||||||
|
case msgChannelOpenFailure: |
||||||
|
msg = new(channelOpenFailureMsg) |
||||||
|
case msgChannelWindowAdjust: |
||||||
|
msg = new(windowAdjustMsg) |
||||||
|
case msgChannelEOF: |
||||||
|
msg = new(channelEOFMsg) |
||||||
|
case msgChannelClose: |
||||||
|
msg = new(channelCloseMsg) |
||||||
|
case msgChannelRequest: |
||||||
|
msg = new(channelRequestMsg) |
||||||
|
case msgChannelSuccess: |
||||||
|
msg = new(channelRequestSuccessMsg) |
||||||
|
case msgChannelFailure: |
||||||
|
msg = new(channelRequestFailureMsg) |
||||||
|
default: |
||||||
|
return nil, unexpectedMessageError(0, packet[0]) |
||||||
|
} |
||||||
|
if err := Unmarshal(packet, msg); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return msg, nil |
||||||
|
} |
@ -0,0 +1,254 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"math/big" |
||||||
|
"math/rand" |
||||||
|
"reflect" |
||||||
|
"testing" |
||||||
|
"testing/quick" |
||||||
|
) |
||||||
|
|
||||||
|
var intLengthTests = []struct { |
||||||
|
val, length int |
||||||
|
}{ |
||||||
|
{0, 4 + 0}, |
||||||
|
{1, 4 + 1}, |
||||||
|
{127, 4 + 1}, |
||||||
|
{128, 4 + 2}, |
||||||
|
{-1, 4 + 1}, |
||||||
|
} |
||||||
|
|
||||||
|
func TestIntLength(t *testing.T) { |
||||||
|
for _, test := range intLengthTests { |
||||||
|
v := new(big.Int).SetInt64(int64(test.val)) |
||||||
|
length := intLength(v) |
||||||
|
if length != test.length { |
||||||
|
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type msgAllTypes struct { |
||||||
|
Bool bool `sshtype:"21"` |
||||||
|
Array [16]byte |
||||||
|
Uint64 uint64 |
||||||
|
Uint32 uint32 |
||||||
|
Uint8 uint8 |
||||||
|
String string |
||||||
|
Strings []string |
||||||
|
Bytes []byte |
||||||
|
Int *big.Int |
||||||
|
Rest []byte `ssh:"rest"` |
||||||
|
} |
||||||
|
|
||||||
|
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { |
||||||
|
m := &msgAllTypes{} |
||||||
|
m.Bool = rand.Intn(2) == 1 |
||||||
|
randomBytes(m.Array[:], rand) |
||||||
|
m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) |
||||||
|
m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) |
||||||
|
m.Uint8 = uint8(rand.Intn(1 << 8)) |
||||||
|
m.String = string(m.Array[:]) |
||||||
|
m.Strings = randomNameList(rand) |
||||||
|
m.Bytes = m.Array[:] |
||||||
|
m.Int = randomInt(rand) |
||||||
|
m.Rest = m.Array[:] |
||||||
|
return reflect.ValueOf(m) |
||||||
|
} |
||||||
|
|
||||||
|
func TestMarshalUnmarshal(t *testing.T) { |
||||||
|
rand := rand.New(rand.NewSource(0)) |
||||||
|
iface := &msgAllTypes{} |
||||||
|
ty := reflect.ValueOf(iface).Type() |
||||||
|
|
||||||
|
n := 100 |
||||||
|
if testing.Short() { |
||||||
|
n = 5 |
||||||
|
} |
||||||
|
for j := 0; j < n; j++ { |
||||||
|
v, ok := quick.Value(ty, rand) |
||||||
|
if !ok { |
||||||
|
t.Errorf("failed to create value") |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
m1 := v.Elem().Interface() |
||||||
|
m2 := iface |
||||||
|
|
||||||
|
marshaled := Marshal(m1) |
||||||
|
if err := Unmarshal(marshaled, m2); err != nil { |
||||||
|
t.Errorf("Unmarshal %#v: %s", m1, err) |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
if !reflect.DeepEqual(v.Interface(), m2) { |
||||||
|
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestUnmarshalEmptyPacket(t *testing.T) { |
||||||
|
var b []byte |
||||||
|
var m channelRequestSuccessMsg |
||||||
|
if err := Unmarshal(b, &m); err == nil { |
||||||
|
t.Fatalf("unmarshal of empty slice succeeded") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestUnmarshalUnexpectedPacket(t *testing.T) { |
||||||
|
type S struct { |
||||||
|
I uint32 `sshtype:"43"` |
||||||
|
S string |
||||||
|
B bool |
||||||
|
} |
||||||
|
|
||||||
|
s := S{11, "hello", true} |
||||||
|
packet := Marshal(s) |
||||||
|
packet[0] = 42 |
||||||
|
roundtrip := S{} |
||||||
|
err := Unmarshal(packet, &roundtrip) |
||||||
|
if err == nil { |
||||||
|
t.Fatal("expected error, not nil") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMarshalPtr(t *testing.T) { |
||||||
|
s := struct { |
||||||
|
S string |
||||||
|
}{"hello"} |
||||||
|
|
||||||
|
m1 := Marshal(s) |
||||||
|
m2 := Marshal(&s) |
||||||
|
if !bytes.Equal(m1, m2) { |
||||||
|
t.Errorf("got %q, want %q for marshaled pointer", m2, m1) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestBareMarshalUnmarshal(t *testing.T) { |
||||||
|
type S struct { |
||||||
|
I uint32 |
||||||
|
S string |
||||||
|
B bool |
||||||
|
} |
||||||
|
|
||||||
|
s := S{42, "hello", true} |
||||||
|
packet := Marshal(s) |
||||||
|
roundtrip := S{} |
||||||
|
Unmarshal(packet, &roundtrip) |
||||||
|
|
||||||
|
if !reflect.DeepEqual(s, roundtrip) { |
||||||
|
t.Errorf("got %#v, want %#v", roundtrip, s) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestBareMarshal(t *testing.T) { |
||||||
|
type S2 struct { |
||||||
|
I uint32 |
||||||
|
} |
||||||
|
s := S2{42} |
||||||
|
packet := Marshal(s) |
||||||
|
i, rest, ok := parseUint32(packet) |
||||||
|
if len(rest) > 0 || !ok { |
||||||
|
t.Errorf("parseInt(%q): parse error", packet) |
||||||
|
} |
||||||
|
if i != s.I { |
||||||
|
t.Errorf("got %d, want %d", i, s.I) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestUnmarshalShortKexInitPacket(t *testing.T) { |
||||||
|
// This used to panic.
|
||||||
|
// Issue 11348
|
||||||
|
packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} |
||||||
|
kim := &kexInitMsg{} |
||||||
|
if err := Unmarshal(packet, kim); err == nil { |
||||||
|
t.Error("truncated packet unmarshaled without error") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func randomBytes(out []byte, rand *rand.Rand) { |
||||||
|
for i := 0; i < len(out); i++ { |
||||||
|
out[i] = byte(rand.Int31()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func randomNameList(rand *rand.Rand) []string { |
||||||
|
ret := make([]string, rand.Int31()&15) |
||||||
|
for i := range ret { |
||||||
|
s := make([]byte, 1+(rand.Int31()&15)) |
||||||
|
for j := range s { |
||||||
|
s[j] = 'a' + uint8(rand.Int31()&15) |
||||||
|
} |
||||||
|
ret[i] = string(s) |
||||||
|
} |
||||||
|
return ret |
||||||
|
} |
||||||
|
|
||||||
|
func randomInt(rand *rand.Rand) *big.Int { |
||||||
|
return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) |
||||||
|
} |
||||||
|
|
||||||
|
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
||||||
|
ki := &kexInitMsg{} |
||||||
|
randomBytes(ki.Cookie[:], rand) |
||||||
|
ki.KexAlgos = randomNameList(rand) |
||||||
|
ki.ServerHostKeyAlgos = randomNameList(rand) |
||||||
|
ki.CiphersClientServer = randomNameList(rand) |
||||||
|
ki.CiphersServerClient = randomNameList(rand) |
||||||
|
ki.MACsClientServer = randomNameList(rand) |
||||||
|
ki.MACsServerClient = randomNameList(rand) |
||||||
|
ki.CompressionClientServer = randomNameList(rand) |
||||||
|
ki.CompressionServerClient = randomNameList(rand) |
||||||
|
ki.LanguagesClientServer = randomNameList(rand) |
||||||
|
ki.LanguagesServerClient = randomNameList(rand) |
||||||
|
if rand.Int31()&1 == 1 { |
||||||
|
ki.FirstKexFollows = true |
||||||
|
} |
||||||
|
return reflect.ValueOf(ki) |
||||||
|
} |
||||||
|
|
||||||
|
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
||||||
|
dhi := &kexDHInitMsg{} |
||||||
|
dhi.X = randomInt(rand) |
||||||
|
return reflect.ValueOf(dhi) |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
||||||
|
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
||||||
|
|
||||||
|
_kexInit = Marshal(_kexInitMsg) |
||||||
|
_kexDHInit = Marshal(_kexDHInitMsg) |
||||||
|
) |
||||||
|
|
||||||
|
func BenchmarkMarshalKexInitMsg(b *testing.B) { |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
Marshal(_kexInitMsg) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func BenchmarkUnmarshalKexInitMsg(b *testing.B) { |
||||||
|
m := new(kexInitMsg) |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
Unmarshal(_kexInit, m) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func BenchmarkMarshalKexDHInitMsg(b *testing.B) { |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
Marshal(_kexDHInitMsg) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { |
||||||
|
m := new(kexDHInitMsg) |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
Unmarshal(_kexDHInit, m) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,356 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/binary" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"log" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
) |
||||||
|
|
||||||
|
// debugMux, if set, causes messages in the connection protocol to be
|
||||||
|
// logged.
|
||||||
|
const debugMux = false |
||||||
|
|
||||||
|
// chanList is a thread safe channel list.
|
||||||
|
type chanList struct { |
||||||
|
// protects concurrent access to chans
|
||||||
|
sync.Mutex |
||||||
|
|
||||||
|
// chans are indexed by the local id of the channel, which the
|
||||||
|
// other side should send in the PeersId field.
|
||||||
|
chans []*channel |
||||||
|
|
||||||
|
// This is a debugging aid: it offsets all IDs by this
|
||||||
|
// amount. This helps distinguish otherwise identical
|
||||||
|
// server/client muxes
|
||||||
|
offset uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// Assigns a channel ID to the given channel.
|
||||||
|
func (c *chanList) add(ch *channel) uint32 { |
||||||
|
c.Lock() |
||||||
|
defer c.Unlock() |
||||||
|
for i := range c.chans { |
||||||
|
if c.chans[i] == nil { |
||||||
|
c.chans[i] = ch |
||||||
|
return uint32(i) + c.offset |
||||||
|
} |
||||||
|
} |
||||||
|
c.chans = append(c.chans, ch) |
||||||
|
return uint32(len(c.chans)-1) + c.offset |
||||||
|
} |
||||||
|
|
||||||
|
// getChan returns the channel for the given ID.
|
||||||
|
func (c *chanList) getChan(id uint32) *channel { |
||||||
|
id -= c.offset |
||||||
|
|
||||||
|
c.Lock() |
||||||
|
defer c.Unlock() |
||||||
|
if id < uint32(len(c.chans)) { |
||||||
|
return c.chans[id] |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *chanList) remove(id uint32) { |
||||||
|
id -= c.offset |
||||||
|
c.Lock() |
||||||
|
if id < uint32(len(c.chans)) { |
||||||
|
c.chans[id] = nil |
||||||
|
} |
||||||
|
c.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// dropAll forgets all channels it knows, returning them in a slice.
|
||||||
|
func (c *chanList) dropAll() []*channel { |
||||||
|
c.Lock() |
||||||
|
defer c.Unlock() |
||||||
|
var r []*channel |
||||||
|
|
||||||
|
for _, ch := range c.chans { |
||||||
|
if ch == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
r = append(r, ch) |
||||||
|
} |
||||||
|
c.chans = nil |
||||||
|
return r |
||||||
|
} |
||||||
|
|
||||||
|
// mux represents the state for the SSH connection protocol, which
|
||||||
|
// multiplexes many channels onto a single packet transport.
|
||||||
|
type mux struct { |
||||||
|
conn packetConn |
||||||
|
chanList chanList |
||||||
|
|
||||||
|
incomingChannels chan NewChannel |
||||||
|
|
||||||
|
globalSentMu sync.Mutex |
||||||
|
globalResponses chan interface{} |
||||||
|
incomingRequests chan *Request |
||||||
|
|
||||||
|
errCond *sync.Cond |
||||||
|
err error |
||||||
|
} |
||||||
|
|
||||||
|
// When debugging, each new chanList instantiation has a different
|
||||||
|
// offset.
|
||||||
|
var globalOff uint32 |
||||||
|
|
||||||
|
func (m *mux) Wait() error { |
||||||
|
m.errCond.L.Lock() |
||||||
|
defer m.errCond.L.Unlock() |
||||||
|
for m.err == nil { |
||||||
|
m.errCond.Wait() |
||||||
|
} |
||||||
|
return m.err |
||||||
|
} |
||||||
|
|
||||||
|
// newMux returns a mux that runs over the given connection.
|
||||||
|
func newMux(p packetConn) *mux { |
||||||
|
m := &mux{ |
||||||
|
conn: p, |
||||||
|
incomingChannels: make(chan NewChannel, 16), |
||||||
|
globalResponses: make(chan interface{}, 1), |
||||||
|
incomingRequests: make(chan *Request, 16), |
||||||
|
errCond: newCond(), |
||||||
|
} |
||||||
|
if debugMux { |
||||||
|
m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
||||||
|
} |
||||||
|
|
||||||
|
go m.loop() |
||||||
|
return m |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) sendMessage(msg interface{}) error { |
||||||
|
p := Marshal(msg) |
||||||
|
return m.conn.writePacket(p) |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { |
||||||
|
if wantReply { |
||||||
|
m.globalSentMu.Lock() |
||||||
|
defer m.globalSentMu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
if err := m.sendMessage(globalRequestMsg{ |
||||||
|
Type: name, |
||||||
|
WantReply: wantReply, |
||||||
|
Data: payload, |
||||||
|
}); err != nil { |
||||||
|
return false, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if !wantReply { |
||||||
|
return false, nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
msg, ok := <-m.globalResponses |
||||||
|
if !ok { |
||||||
|
return false, nil, io.EOF |
||||||
|
} |
||||||
|
switch msg := msg.(type) { |
||||||
|
case *globalRequestFailureMsg: |
||||||
|
return false, msg.Data, nil |
||||||
|
case *globalRequestSuccessMsg: |
||||||
|
return true, msg.Data, nil |
||||||
|
default: |
||||||
|
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ackRequest must be called after processing a global request that
|
||||||
|
// has WantReply set.
|
||||||
|
func (m *mux) ackRequest(ok bool, data []byte) error { |
||||||
|
if ok { |
||||||
|
return m.sendMessage(globalRequestSuccessMsg{Data: data}) |
||||||
|
} |
||||||
|
return m.sendMessage(globalRequestFailureMsg{Data: data}) |
||||||
|
} |
||||||
|
|
||||||
|
// TODO(hanwen): Disconnect is a transport layer message. We should
|
||||||
|
// probably send and receive Disconnect somewhere in the transport
|
||||||
|
// code.
|
||||||
|
|
||||||
|
// Disconnect sends a disconnect message.
|
||||||
|
func (m *mux) Disconnect(reason uint32, message string) error { |
||||||
|
return m.sendMessage(disconnectMsg{ |
||||||
|
Reason: reason, |
||||||
|
Message: message, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) Close() error { |
||||||
|
return m.conn.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// loop runs the connection machine. It will process packets until an
|
||||||
|
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
||||||
|
func (m *mux) loop() { |
||||||
|
var err error |
||||||
|
for err == nil { |
||||||
|
err = m.onePacket() |
||||||
|
} |
||||||
|
|
||||||
|
for _, ch := range m.chanList.dropAll() { |
||||||
|
ch.close() |
||||||
|
} |
||||||
|
|
||||||
|
close(m.incomingChannels) |
||||||
|
close(m.incomingRequests) |
||||||
|
close(m.globalResponses) |
||||||
|
|
||||||
|
m.conn.Close() |
||||||
|
|
||||||
|
m.errCond.L.Lock() |
||||||
|
m.err = err |
||||||
|
m.errCond.Broadcast() |
||||||
|
m.errCond.L.Unlock() |
||||||
|
|
||||||
|
if debugMux { |
||||||
|
log.Println("loop exit", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// onePacket reads and processes one packet.
|
||||||
|
func (m *mux) onePacket() error { |
||||||
|
packet, err := m.conn.readPacket() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if debugMux { |
||||||
|
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { |
||||||
|
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) |
||||||
|
} else { |
||||||
|
p, _ := decode(packet) |
||||||
|
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
switch packet[0] { |
||||||
|
case msgNewKeys: |
||||||
|
// Ignore notification of key change.
|
||||||
|
return nil |
||||||
|
case msgDisconnect: |
||||||
|
return m.handleDisconnect(packet) |
||||||
|
case msgChannelOpen: |
||||||
|
return m.handleChannelOpen(packet) |
||||||
|
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: |
||||||
|
return m.handleGlobalPacket(packet) |
||||||
|
} |
||||||
|
|
||||||
|
// assume a channel packet.
|
||||||
|
if len(packet) < 5 { |
||||||
|
return parseError(packet[0]) |
||||||
|
} |
||||||
|
id := binary.BigEndian.Uint32(packet[1:]) |
||||||
|
ch := m.chanList.getChan(id) |
||||||
|
if ch == nil { |
||||||
|
return fmt.Errorf("ssh: invalid channel %d", id) |
||||||
|
} |
||||||
|
|
||||||
|
return ch.handlePacket(packet) |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) handleDisconnect(packet []byte) error { |
||||||
|
var d disconnectMsg |
||||||
|
if err := Unmarshal(packet, &d); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if debugMux { |
||||||
|
log.Printf("caught disconnect: %v", d) |
||||||
|
} |
||||||
|
return &d |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) handleGlobalPacket(packet []byte) error { |
||||||
|
msg, err := decode(packet) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
switch msg := msg.(type) { |
||||||
|
case *globalRequestMsg: |
||||||
|
m.incomingRequests <- &Request{ |
||||||
|
Type: msg.Type, |
||||||
|
WantReply: msg.WantReply, |
||||||
|
Payload: msg.Data, |
||||||
|
mux: m, |
||||||
|
} |
||||||
|
case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
||||||
|
m.globalResponses <- msg |
||||||
|
default: |
||||||
|
panic(fmt.Sprintf("not a global message %#v", msg)) |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// handleChannelOpen schedules a channel to be Accept()ed.
|
||||||
|
func (m *mux) handleChannelOpen(packet []byte) error { |
||||||
|
var msg channelOpenMsg |
||||||
|
if err := Unmarshal(packet, &msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||||
|
failMsg := channelOpenFailureMsg{ |
||||||
|
PeersId: msg.PeersId, |
||||||
|
Reason: ConnectionFailed, |
||||||
|
Message: "invalid request", |
||||||
|
Language: "en_US.UTF-8", |
||||||
|
} |
||||||
|
return m.sendMessage(failMsg) |
||||||
|
} |
||||||
|
|
||||||
|
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) |
||||||
|
c.remoteId = msg.PeersId |
||||||
|
c.maxRemotePayload = msg.MaxPacketSize |
||||||
|
c.remoteWin.add(msg.PeersWindow) |
||||||
|
m.incomingChannels <- c |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { |
||||||
|
ch, err := m.openChannel(chanType, extra) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return ch, ch.incomingRequests, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { |
||||||
|
ch := m.newChannel(chanType, channelOutbound, extra) |
||||||
|
|
||||||
|
ch.maxIncomingPayload = channelMaxPacket |
||||||
|
|
||||||
|
open := channelOpenMsg{ |
||||||
|
ChanType: chanType, |
||||||
|
PeersWindow: ch.myWindow, |
||||||
|
MaxPacketSize: ch.maxIncomingPayload, |
||||||
|
TypeSpecificData: extra, |
||||||
|
PeersId: ch.localId, |
||||||
|
} |
||||||
|
if err := m.sendMessage(open); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
switch msg := (<-ch.msg).(type) { |
||||||
|
case *channelOpenConfirmMsg: |
||||||
|
return ch, nil |
||||||
|
case *channelOpenFailureMsg: |
||||||
|
return nil, &OpenChannelError{msg.Reason, msg.Message} |
||||||
|
default: |
||||||
|
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,525 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
"sync" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func muxPair() (*mux, *mux) { |
||||||
|
a, b := memPipe() |
||||||
|
|
||||||
|
s := newMux(a) |
||||||
|
c := newMux(b) |
||||||
|
|
||||||
|
return s, c |
||||||
|
} |
||||||
|
|
||||||
|
// Returns both ends of a channel, and the mux for the the 2nd
|
||||||
|
// channel.
|
||||||
|
func channelPair(t *testing.T) (*channel, *channel, *mux) { |
||||||
|
c, s := muxPair() |
||||||
|
|
||||||
|
res := make(chan *channel, 1) |
||||||
|
go func() { |
||||||
|
newCh, ok := <-s.incomingChannels |
||||||
|
if !ok { |
||||||
|
t.Fatalf("No incoming channel") |
||||||
|
} |
||||||
|
if newCh.ChannelType() != "chan" { |
||||||
|
t.Fatalf("got type %q want chan", newCh.ChannelType()) |
||||||
|
} |
||||||
|
ch, _, err := newCh.Accept() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Accept %v", err) |
||||||
|
} |
||||||
|
res <- ch.(*channel) |
||||||
|
}() |
||||||
|
|
||||||
|
ch, err := c.openChannel("chan", nil) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("OpenChannel: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
return <-res, ch, c |
||||||
|
} |
||||||
|
|
||||||
|
// Test that stderr and stdout can be addressed from different
|
||||||
|
// goroutines. This is intended for use with the race detector.
|
||||||
|
func TestMuxChannelExtendedThreadSafety(t *testing.T) { |
||||||
|
writer, reader, mux := channelPair(t) |
||||||
|
defer writer.Close() |
||||||
|
defer reader.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
var wr, rd sync.WaitGroup |
||||||
|
magic := "hello world" |
||||||
|
|
||||||
|
wr.Add(2) |
||||||
|
go func() { |
||||||
|
io.WriteString(writer, magic) |
||||||
|
wr.Done() |
||||||
|
}() |
||||||
|
go func() { |
||||||
|
io.WriteString(writer.Stderr(), magic) |
||||||
|
wr.Done() |
||||||
|
}() |
||||||
|
|
||||||
|
rd.Add(2) |
||||||
|
go func() { |
||||||
|
c, err := ioutil.ReadAll(reader) |
||||||
|
if string(c) != magic { |
||||||
|
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) |
||||||
|
} |
||||||
|
rd.Done() |
||||||
|
}() |
||||||
|
go func() { |
||||||
|
c, err := ioutil.ReadAll(reader.Stderr()) |
||||||
|
if string(c) != magic { |
||||||
|
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) |
||||||
|
} |
||||||
|
rd.Done() |
||||||
|
}() |
||||||
|
|
||||||
|
wr.Wait() |
||||||
|
writer.CloseWrite() |
||||||
|
rd.Wait() |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxReadWrite(t *testing.T) { |
||||||
|
s, c, mux := channelPair(t) |
||||||
|
defer s.Close() |
||||||
|
defer c.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
magic := "hello world" |
||||||
|
magicExt := "hello stderr" |
||||||
|
go func() { |
||||||
|
_, err := s.Write([]byte(magic)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Write: %v", err) |
||||||
|
} |
||||||
|
_, err = s.Extended(1).Write([]byte(magicExt)) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Write: %v", err) |
||||||
|
} |
||||||
|
err = s.Close() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Close: %v", err) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
var buf [1024]byte |
||||||
|
n, err := c.Read(buf[:]) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("server Read: %v", err) |
||||||
|
} |
||||||
|
got := string(buf[:n]) |
||||||
|
if got != magic { |
||||||
|
t.Fatalf("server: got %q want %q", got, magic) |
||||||
|
} |
||||||
|
|
||||||
|
n, err = c.Extended(1).Read(buf[:]) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("server Read: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
got = string(buf[:n]) |
||||||
|
if got != magicExt { |
||||||
|
t.Fatalf("server: got %q want %q", got, magic) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxChannelOverflow(t *testing.T) { |
||||||
|
reader, writer, mux := channelPair(t) |
||||||
|
defer reader.Close() |
||||||
|
defer writer.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
wDone := make(chan int, 1) |
||||||
|
go func() { |
||||||
|
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||||
|
t.Errorf("could not fill window: %v", err) |
||||||
|
} |
||||||
|
writer.Write(make([]byte, 1)) |
||||||
|
wDone <- 1 |
||||||
|
}() |
||||||
|
writer.remoteWin.waitWriterBlocked() |
||||||
|
|
||||||
|
// Send 1 byte.
|
||||||
|
packet := make([]byte, 1+4+4+1) |
||||||
|
packet[0] = msgChannelData |
||||||
|
marshalUint32(packet[1:], writer.remoteId) |
||||||
|
marshalUint32(packet[5:], uint32(1)) |
||||||
|
packet[9] = 42 |
||||||
|
|
||||||
|
if err := writer.mux.conn.writePacket(packet); err != nil { |
||||||
|
t.Errorf("could not send packet") |
||||||
|
} |
||||||
|
if _, err := reader.SendRequest("hello", true, nil); err == nil { |
||||||
|
t.Errorf("SendRequest succeeded.") |
||||||
|
} |
||||||
|
<-wDone |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxChannelCloseWriteUnblock(t *testing.T) { |
||||||
|
reader, writer, mux := channelPair(t) |
||||||
|
defer reader.Close() |
||||||
|
defer writer.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
wDone := make(chan int, 1) |
||||||
|
go func() { |
||||||
|
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||||
|
t.Errorf("could not fill window: %v", err) |
||||||
|
} |
||||||
|
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { |
||||||
|
t.Errorf("got %v, want EOF for unblock write", err) |
||||||
|
} |
||||||
|
wDone <- 1 |
||||||
|
}() |
||||||
|
|
||||||
|
writer.remoteWin.waitWriterBlocked() |
||||||
|
reader.Close() |
||||||
|
<-wDone |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxConnectionCloseWriteUnblock(t *testing.T) { |
||||||
|
reader, writer, mux := channelPair(t) |
||||||
|
defer reader.Close() |
||||||
|
defer writer.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
wDone := make(chan int, 1) |
||||||
|
go func() { |
||||||
|
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||||
|
t.Errorf("could not fill window: %v", err) |
||||||
|
} |
||||||
|
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { |
||||||
|
t.Errorf("got %v, want EOF for unblock write", err) |
||||||
|
} |
||||||
|
wDone <- 1 |
||||||
|
}() |
||||||
|
|
||||||
|
writer.remoteWin.waitWriterBlocked() |
||||||
|
mux.Close() |
||||||
|
<-wDone |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxReject(t *testing.T) { |
||||||
|
client, server := muxPair() |
||||||
|
defer server.Close() |
||||||
|
defer client.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
ch, ok := <-server.incomingChannels |
||||||
|
if !ok { |
||||||
|
t.Fatalf("Accept") |
||||||
|
} |
||||||
|
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { |
||||||
|
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) |
||||||
|
} |
||||||
|
ch.Reject(RejectionReason(42), "message") |
||||||
|
}() |
||||||
|
|
||||||
|
ch, err := client.openChannel("ch", []byte("extra")) |
||||||
|
if ch != nil { |
||||||
|
t.Fatal("openChannel not rejected") |
||||||
|
} |
||||||
|
|
||||||
|
ocf, ok := err.(*OpenChannelError) |
||||||
|
if !ok { |
||||||
|
t.Errorf("got %#v want *OpenChannelError", err) |
||||||
|
} else if ocf.Reason != 42 || ocf.Message != "message" { |
||||||
|
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") |
||||||
|
} |
||||||
|
|
||||||
|
want := "ssh: rejected: unknown reason 42 (message)" |
||||||
|
if err.Error() != want { |
||||||
|
t.Errorf("got %q, want %q", err.Error(), want) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxChannelRequest(t *testing.T) { |
||||||
|
client, server, mux := channelPair(t) |
||||||
|
defer server.Close() |
||||||
|
defer client.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
var received int |
||||||
|
var wg sync.WaitGroup |
||||||
|
wg.Add(1) |
||||||
|
go func() { |
||||||
|
for r := range server.incomingRequests { |
||||||
|
received++ |
||||||
|
r.Reply(r.Type == "yes", nil) |
||||||
|
} |
||||||
|
wg.Done() |
||||||
|
}() |
||||||
|
_, err := client.SendRequest("yes", false, nil) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("SendRequest: %v", err) |
||||||
|
} |
||||||
|
ok, err := client.SendRequest("yes", true, nil) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("SendRequest: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if !ok { |
||||||
|
t.Errorf("SendRequest(yes): %v", ok) |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
ok, err = client.SendRequest("no", true, nil) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("SendRequest: %v", err) |
||||||
|
} |
||||||
|
if ok { |
||||||
|
t.Errorf("SendRequest(no): %v", ok) |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
client.Close() |
||||||
|
wg.Wait() |
||||||
|
|
||||||
|
if received != 3 { |
||||||
|
t.Errorf("got %d requests, want %d", received, 3) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxGlobalRequest(t *testing.T) { |
||||||
|
clientMux, serverMux := muxPair() |
||||||
|
defer serverMux.Close() |
||||||
|
defer clientMux.Close() |
||||||
|
|
||||||
|
var seen bool |
||||||
|
go func() { |
||||||
|
for r := range serverMux.incomingRequests { |
||||||
|
seen = seen || r.Type == "peek" |
||||||
|
if r.WantReply { |
||||||
|
err := r.Reply(r.Type == "yes", |
||||||
|
append([]byte(r.Type), r.Payload...)) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("AckRequest: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
_, _, err := clientMux.SendRequest("peek", false, nil) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("SendRequest: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) |
||||||
|
if !ok || string(data) != "yesa" || err != nil { |
||||||
|
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
||||||
|
ok, data, err) |
||||||
|
} |
||||||
|
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { |
||||||
|
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
||||||
|
ok, data, err) |
||||||
|
} |
||||||
|
|
||||||
|
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { |
||||||
|
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", |
||||||
|
ok, data, err) |
||||||
|
} |
||||||
|
|
||||||
|
clientMux.Disconnect(0, "") |
||||||
|
if !seen { |
||||||
|
t.Errorf("never saw 'peek' request") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxGlobalRequestUnblock(t *testing.T) { |
||||||
|
clientMux, serverMux := muxPair() |
||||||
|
defer serverMux.Close() |
||||||
|
defer clientMux.Close() |
||||||
|
|
||||||
|
result := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
_, _, err := clientMux.SendRequest("hello", true, nil) |
||||||
|
result <- err |
||||||
|
}() |
||||||
|
|
||||||
|
<-serverMux.incomingRequests |
||||||
|
serverMux.conn.Close() |
||||||
|
err := <-result |
||||||
|
|
||||||
|
if err != io.EOF { |
||||||
|
t.Errorf("want EOF, got %v", io.EOF) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxChannelRequestUnblock(t *testing.T) { |
||||||
|
a, b, connB := channelPair(t) |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
defer connB.Close() |
||||||
|
|
||||||
|
result := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
_, err := a.SendRequest("hello", true, nil) |
||||||
|
result <- err |
||||||
|
}() |
||||||
|
|
||||||
|
<-b.incomingRequests |
||||||
|
connB.conn.Close() |
||||||
|
err := <-result |
||||||
|
|
||||||
|
if err != io.EOF { |
||||||
|
t.Errorf("want EOF, got %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxDisconnect(t *testing.T) { |
||||||
|
a, b := muxPair() |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
for r := range b.incomingRequests { |
||||||
|
r.Reply(true, nil) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
a.Disconnect(42, "whatever") |
||||||
|
ok, _, err := a.SendRequest("hello", true, nil) |
||||||
|
if ok || err == nil { |
||||||
|
t.Errorf("got reply after disconnecting") |
||||||
|
} |
||||||
|
err = b.Wait() |
||||||
|
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { |
||||||
|
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxCloseChannel(t *testing.T) { |
||||||
|
r, w, mux := channelPair(t) |
||||||
|
defer mux.Close() |
||||||
|
defer r.Close() |
||||||
|
defer w.Close() |
||||||
|
|
||||||
|
result := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
var b [1024]byte |
||||||
|
_, err := r.Read(b[:]) |
||||||
|
result <- err |
||||||
|
}() |
||||||
|
if err := w.Close(); err != nil { |
||||||
|
t.Errorf("w.Close: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := w.Write([]byte("hello")); err != io.EOF { |
||||||
|
t.Errorf("got err %v, want io.EOF after Close", err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := <-result; err != io.EOF { |
||||||
|
t.Errorf("got %v (%T), want io.EOF", err, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxCloseWriteChannel(t *testing.T) { |
||||||
|
r, w, mux := channelPair(t) |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
result := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
var b [1024]byte |
||||||
|
_, err := r.Read(b[:]) |
||||||
|
result <- err |
||||||
|
}() |
||||||
|
if err := w.CloseWrite(); err != nil { |
||||||
|
t.Errorf("w.CloseWrite: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := w.Write([]byte("hello")); err != io.EOF { |
||||||
|
t.Errorf("got err %v, want io.EOF after CloseWrite", err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := <-result; err != io.EOF { |
||||||
|
t.Errorf("got %v (%T), want io.EOF", err, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxInvalidRecord(t *testing.T) { |
||||||
|
a, b := muxPair() |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
|
||||||
|
packet := make([]byte, 1+4+4+1) |
||||||
|
packet[0] = msgChannelData |
||||||
|
marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
||||||
|
marshalUint32(packet[5:], 1) |
||||||
|
packet[9] = 42 |
||||||
|
|
||||||
|
a.conn.writePacket(packet) |
||||||
|
go a.SendRequest("hello", false, nil) |
||||||
|
// 'a' wrote an invalid packet, so 'b' has exited.
|
||||||
|
req, ok := <-b.incomingRequests |
||||||
|
if ok { |
||||||
|
t.Errorf("got request %#v after receiving invalid packet", req) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestZeroWindowAdjust(t *testing.T) { |
||||||
|
a, b, mux := channelPair(t) |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
io.WriteString(a, "hello") |
||||||
|
// bogus adjust.
|
||||||
|
a.sendMessage(windowAdjustMsg{}) |
||||||
|
io.WriteString(a, "world") |
||||||
|
a.Close() |
||||||
|
}() |
||||||
|
|
||||||
|
want := "helloworld" |
||||||
|
c, _ := ioutil.ReadAll(b) |
||||||
|
if string(c) != want { |
||||||
|
t.Errorf("got %q want %q", c, want) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMuxMaxPacketSize(t *testing.T) { |
||||||
|
a, b, mux := channelPair(t) |
||||||
|
defer a.Close() |
||||||
|
defer b.Close() |
||||||
|
defer mux.Close() |
||||||
|
|
||||||
|
large := make([]byte, a.maxRemotePayload+1) |
||||||
|
packet := make([]byte, 1+4+4+1+len(large)) |
||||||
|
packet[0] = msgChannelData |
||||||
|
marshalUint32(packet[1:], a.remoteId) |
||||||
|
marshalUint32(packet[5:], uint32(len(large))) |
||||||
|
packet[9] = 42 |
||||||
|
|
||||||
|
if err := a.mux.conn.writePacket(packet); err != nil { |
||||||
|
t.Errorf("could not send packet") |
||||||
|
} |
||||||
|
|
||||||
|
go a.SendRequest("hello", false, nil) |
||||||
|
|
||||||
|
_, ok := <-b.incomingRequests |
||||||
|
if ok { |
||||||
|
t.Errorf("connection still alive after receiving large packet.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Don't ship code with debug=true.
|
||||||
|
func TestDebug(t *testing.T) { |
||||||
|
if debugMux { |
||||||
|
t.Error("mux debug switched on") |
||||||
|
} |
||||||
|
if debugHandshake { |
||||||
|
t.Error("handshake debug switched on") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,493 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
) |
||||||
|
|
||||||
|
// The Permissions type holds fine-grained permissions that are
|
||||||
|
// specific to a user or a specific authentication method for a
|
||||||
|
// user. Permissions, except for "source-address", must be enforced in
|
||||||
|
// the server application layer, after successful authentication. The
|
||||||
|
// Permissions are passed on in ServerConn so a server implementation
|
||||||
|
// can honor them.
|
||||||
|
type Permissions struct { |
||||||
|
// Critical options restrict default permissions. Common
|
||||||
|
// restrictions are "source-address" and "force-command". If
|
||||||
|
// the server cannot enforce the restriction, or does not
|
||||||
|
// recognize it, the user should not authenticate.
|
||||||
|
CriticalOptions map[string]string |
||||||
|
|
||||||
|
// Extensions are extra functionality that the server may
|
||||||
|
// offer on authenticated connections. Common extensions are
|
||||||
|
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
|
||||||
|
// support for an extension does not preclude authenticating a
|
||||||
|
// user.
|
||||||
|
Extensions map[string]string |
||||||
|
} |
||||||
|
|
||||||
|
// ServerConfig holds server specific configuration data.
|
||||||
|
type ServerConfig struct { |
||||||
|
// Config contains configuration shared between client and server.
|
||||||
|
Config |
||||||
|
|
||||||
|
hostKeys []Signer |
||||||
|
|
||||||
|
// NoClientAuth is true if clients are allowed to connect without
|
||||||
|
// authenticating.
|
||||||
|
NoClientAuth bool |
||||||
|
|
||||||
|
// PasswordCallback, if non-nil, is called when a user
|
||||||
|
// attempts to authenticate using a password.
|
||||||
|
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) |
||||||
|
|
||||||
|
// PublicKeyCallback, if non-nil, is called when a client attempts public
|
||||||
|
// key authentication. It must return true if the given public key is
|
||||||
|
// valid for the given user. For example, see CertChecker.Authenticate.
|
||||||
|
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||||
|
|
||||||
|
// KeyboardInteractiveCallback, if non-nil, is called when
|
||||||
|
// keyboard-interactive authentication is selected (RFC
|
||||||
|
// 4256). The client object's Challenge function should be
|
||||||
|
// used to query the user. The callback may offer multiple
|
||||||
|
// Challenge rounds. To avoid information leaks, the client
|
||||||
|
// should be presented a challenge even if the user is
|
||||||
|
// unknown.
|
||||||
|
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) |
||||||
|
|
||||||
|
// AuthLogCallback, if non-nil, is called to log all authentication
|
||||||
|
// attempts.
|
||||||
|
AuthLogCallback func(conn ConnMetadata, method string, err error) |
||||||
|
|
||||||
|
// ServerVersion is the version identification string to
|
||||||
|
// announce in the public handshake.
|
||||||
|
// If empty, a reasonable default is used.
|
||||||
|
ServerVersion string |
||||||
|
} |
||||||
|
|
||||||
|
// AddHostKey adds a private key as a host key. If an existing host
|
||||||
|
// key exists with the same algorithm, it is overwritten. Each server
|
||||||
|
// config must have at least one host key.
|
||||||
|
func (s *ServerConfig) AddHostKey(key Signer) { |
||||||
|
for i, k := range s.hostKeys { |
||||||
|
if k.PublicKey().Type() == key.PublicKey().Type() { |
||||||
|
s.hostKeys[i] = key |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
s.hostKeys = append(s.hostKeys, key) |
||||||
|
} |
||||||
|
|
||||||
|
// cachedPubKey contains the results of querying whether a public key is
|
||||||
|
// acceptable for a user.
|
||||||
|
type cachedPubKey struct { |
||||||
|
user string |
||||||
|
pubKeyData []byte |
||||||
|
result error |
||||||
|
perms *Permissions |
||||||
|
} |
||||||
|
|
||||||
|
const maxCachedPubKeys = 16 |
||||||
|
|
||||||
|
// pubKeyCache caches tests for public keys. Since SSH clients
|
||||||
|
// will query whether a public key is acceptable before attempting to
|
||||||
|
// authenticate with it, we end up with duplicate queries for public
|
||||||
|
// key validity. The cache only applies to a single ServerConn.
|
||||||
|
type pubKeyCache struct { |
||||||
|
keys []cachedPubKey |
||||||
|
} |
||||||
|
|
||||||
|
// get returns the result for a given user/algo/key tuple.
|
||||||
|
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { |
||||||
|
for _, k := range c.keys { |
||||||
|
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { |
||||||
|
return k, true |
||||||
|
} |
||||||
|
} |
||||||
|
return cachedPubKey{}, false |
||||||
|
} |
||||||
|
|
||||||
|
// add adds the given tuple to the cache.
|
||||||
|
func (c *pubKeyCache) add(candidate cachedPubKey) { |
||||||
|
if len(c.keys) < maxCachedPubKeys { |
||||||
|
c.keys = append(c.keys, candidate) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ServerConn is an authenticated SSH connection, as seen from the
|
||||||
|
// server
|
||||||
|
type ServerConn struct { |
||||||
|
Conn |
||||||
|
|
||||||
|
// If the succeeding authentication callback returned a
|
||||||
|
// non-nil Permissions pointer, it is stored here.
|
||||||
|
Permissions *Permissions |
||||||
|
} |
||||||
|
|
||||||
|
// NewServerConn starts a new SSH server with c as the underlying
|
||||||
|
// transport. It starts with a handshake and, if the handshake is
|
||||||
|
// unsuccessful, it closes the connection and returns an error. The
|
||||||
|
// Request and NewChannel channels must be serviced, or the connection
|
||||||
|
// will hang.
|
||||||
|
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { |
||||||
|
fullConf := *config |
||||||
|
fullConf.SetDefaults() |
||||||
|
s := &connection{ |
||||||
|
sshConn: sshConn{conn: c}, |
||||||
|
} |
||||||
|
perms, err := s.serverHandshake(&fullConf) |
||||||
|
if err != nil { |
||||||
|
c.Close() |
||||||
|
return nil, nil, nil, err |
||||||
|
} |
||||||
|
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil |
||||||
|
} |
||||||
|
|
||||||
|
// signAndMarshal signs the data with the appropriate algorithm,
|
||||||
|
// and serializes the result in SSH wire format.
|
||||||
|
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { |
||||||
|
sig, err := k.Sign(rand, data) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return Marshal(sig), nil |
||||||
|
} |
||||||
|
|
||||||
|
// handshake performs key exchange and user authentication.
|
||||||
|
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { |
||||||
|
if len(config.hostKeys) == 0 { |
||||||
|
return nil, errors.New("ssh: server has no host keys") |
||||||
|
} |
||||||
|
|
||||||
|
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { |
||||||
|
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||||
|
} |
||||||
|
|
||||||
|
if config.ServerVersion != "" { |
||||||
|
s.serverVersion = []byte(config.ServerVersion) |
||||||
|
} else { |
||||||
|
s.serverVersion = []byte(packageVersion) |
||||||
|
} |
||||||
|
var err error |
||||||
|
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) |
||||||
|
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) |
||||||
|
|
||||||
|
if err := s.transport.requestKeyChange(); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if packet, err := s.transport.readPacket(); err != nil { |
||||||
|
return nil, err |
||||||
|
} else if packet[0] != msgNewKeys { |
||||||
|
return nil, unexpectedMessageError(msgNewKeys, packet[0]) |
||||||
|
} |
||||||
|
|
||||||
|
// We just did the key change, so the session ID is established.
|
||||||
|
s.sessionID = s.transport.getSessionID() |
||||||
|
|
||||||
|
var packet []byte |
||||||
|
if packet, err = s.transport.readPacket(); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var serviceRequest serviceRequestMsg |
||||||
|
if err = Unmarshal(packet, &serviceRequest); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if serviceRequest.Service != serviceUserAuth { |
||||||
|
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") |
||||||
|
} |
||||||
|
serviceAccept := serviceAcceptMsg{ |
||||||
|
Service: serviceUserAuth, |
||||||
|
} |
||||||
|
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
perms, err := s.serverAuthenticate(config) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
s.mux = newMux(s.transport) |
||||||
|
return perms, err |
||||||
|
} |
||||||
|
|
||||||
|
func isAcceptableAlgo(algo string) bool { |
||||||
|
switch algo { |
||||||
|
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||||
|
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||||
|
return true |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
func checkSourceAddress(addr net.Addr, sourceAddr string) error { |
||||||
|
if addr == nil { |
||||||
|
return errors.New("ssh: no address known for client, but source-address match required") |
||||||
|
} |
||||||
|
|
||||||
|
tcpAddr, ok := addr.(*net.TCPAddr) |
||||||
|
if !ok { |
||||||
|
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) |
||||||
|
} |
||||||
|
|
||||||
|
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { |
||||||
|
if bytes.Equal(allowedIP, tcpAddr.IP) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
} else { |
||||||
|
_, ipNet, err := net.ParseCIDR(sourceAddr) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) |
||||||
|
} |
||||||
|
|
||||||
|
if ipNet.Contains(tcpAddr.IP) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { |
||||||
|
var err error |
||||||
|
var cache pubKeyCache |
||||||
|
var perms *Permissions |
||||||
|
|
||||||
|
userAuthLoop: |
||||||
|
for { |
||||||
|
var userAuthReq userAuthRequestMsg |
||||||
|
if packet, err := s.transport.readPacket(); err != nil { |
||||||
|
return nil, err |
||||||
|
} else if err = Unmarshal(packet, &userAuthReq); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
if userAuthReq.Service != serviceSSH { |
||||||
|
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) |
||||||
|
} |
||||||
|
|
||||||
|
s.user = userAuthReq.User |
||||||
|
perms = nil |
||||||
|
authErr := errors.New("no auth passed yet") |
||||||
|
|
||||||
|
switch userAuthReq.Method { |
||||||
|
case "none": |
||||||
|
if config.NoClientAuth { |
||||||
|
s.user = "" |
||||||
|
authErr = nil |
||||||
|
} |
||||||
|
case "password": |
||||||
|
if config.PasswordCallback == nil { |
||||||
|
authErr = errors.New("ssh: password auth not configured") |
||||||
|
break |
||||||
|
} |
||||||
|
payload := userAuthReq.Payload |
||||||
|
if len(payload) < 1 || payload[0] != 0 { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
payload = payload[1:] |
||||||
|
password, payload, ok := parseString(payload) |
||||||
|
if !ok || len(payload) > 0 { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
|
||||||
|
perms, authErr = config.PasswordCallback(s, password) |
||||||
|
case "keyboard-interactive": |
||||||
|
if config.KeyboardInteractiveCallback == nil { |
||||||
|
authErr = errors.New("ssh: keyboard-interactive auth not configubred") |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
prompter := &sshClientKeyboardInteractive{s} |
||||||
|
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) |
||||||
|
case "publickey": |
||||||
|
if config.PublicKeyCallback == nil { |
||||||
|
authErr = errors.New("ssh: publickey auth not configured") |
||||||
|
break |
||||||
|
} |
||||||
|
payload := userAuthReq.Payload |
||||||
|
if len(payload) < 1 { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
isQuery := payload[0] == 0 |
||||||
|
payload = payload[1:] |
||||||
|
algoBytes, payload, ok := parseString(payload) |
||||||
|
if !ok { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
algo := string(algoBytes) |
||||||
|
if !isAcceptableAlgo(algo) { |
||||||
|
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
pubKeyData, payload, ok := parseString(payload) |
||||||
|
if !ok { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
|
||||||
|
pubKey, err := ParsePublicKey(pubKeyData) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
candidate, ok := cache.get(s.user, pubKeyData) |
||||||
|
if !ok { |
||||||
|
candidate.user = s.user |
||||||
|
candidate.pubKeyData = pubKeyData |
||||||
|
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) |
||||||
|
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { |
||||||
|
candidate.result = checkSourceAddress( |
||||||
|
s.RemoteAddr(), |
||||||
|
candidate.perms.CriticalOptions[sourceAddressCriticalOption]) |
||||||
|
} |
||||||
|
cache.add(candidate) |
||||||
|
} |
||||||
|
|
||||||
|
if isQuery { |
||||||
|
// The client can query if the given public key
|
||||||
|
// would be okay.
|
||||||
|
if len(payload) > 0 { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
|
||||||
|
if candidate.result == nil { |
||||||
|
okMsg := userAuthPubKeyOkMsg{ |
||||||
|
Algo: algo, |
||||||
|
PubKey: pubKeyData, |
||||||
|
} |
||||||
|
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
continue userAuthLoop |
||||||
|
} |
||||||
|
authErr = candidate.result |
||||||
|
} else { |
||||||
|
sig, payload, ok := parseSignature(payload) |
||||||
|
if !ok || len(payload) > 0 { |
||||||
|
return nil, parseError(msgUserAuthRequest) |
||||||
|
} |
||||||
|
// Ensure the public key algo and signature algo
|
||||||
|
// are supported. Compare the private key
|
||||||
|
// algorithm name that corresponds to algo with
|
||||||
|
// sig.Format. This is usually the same, but
|
||||||
|
// for certs, the names differ.
|
||||||
|
if !isAcceptableAlgo(sig.Format) { |
||||||
|
break |
||||||
|
} |
||||||
|
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) |
||||||
|
|
||||||
|
if err := pubKey.Verify(signedData, sig); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
authErr = candidate.result |
||||||
|
perms = candidate.perms |
||||||
|
} |
||||||
|
default: |
||||||
|
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) |
||||||
|
} |
||||||
|
|
||||||
|
if config.AuthLogCallback != nil { |
||||||
|
config.AuthLogCallback(s, userAuthReq.Method, authErr) |
||||||
|
} |
||||||
|
|
||||||
|
if authErr == nil { |
||||||
|
break userAuthLoop |
||||||
|
} |
||||||
|
|
||||||
|
var failureMsg userAuthFailureMsg |
||||||
|
if config.PasswordCallback != nil { |
||||||
|
failureMsg.Methods = append(failureMsg.Methods, "password") |
||||||
|
} |
||||||
|
if config.PublicKeyCallback != nil { |
||||||
|
failureMsg.Methods = append(failureMsg.Methods, "publickey") |
||||||
|
} |
||||||
|
if config.KeyboardInteractiveCallback != nil { |
||||||
|
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") |
||||||
|
} |
||||||
|
|
||||||
|
if len(failureMsg.Methods) == 0 { |
||||||
|
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||||
|
} |
||||||
|
|
||||||
|
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return perms, nil |
||||||
|
} |
||||||
|
|
||||||
|
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
||||||
|
// asking the client on the other side of a ServerConn.
|
||||||
|
type sshClientKeyboardInteractive struct { |
||||||
|
*connection |
||||||
|
} |
||||||
|
|
||||||
|
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { |
||||||
|
if len(questions) != len(echos) { |
||||||
|
return nil, errors.New("ssh: echos and questions must have equal length") |
||||||
|
} |
||||||
|
|
||||||
|
var prompts []byte |
||||||
|
for i := range questions { |
||||||
|
prompts = appendString(prompts, questions[i]) |
||||||
|
prompts = appendBool(prompts, echos[i]) |
||||||
|
} |
||||||
|
|
||||||
|
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ |
||||||
|
Instruction: instruction, |
||||||
|
NumPrompts: uint32(len(questions)), |
||||||
|
Prompts: prompts, |
||||||
|
})); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
packet, err := c.transport.readPacket() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if packet[0] != msgUserAuthInfoResponse { |
||||||
|
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) |
||||||
|
} |
||||||
|
packet = packet[1:] |
||||||
|
|
||||||
|
n, packet, ok := parseUint32(packet) |
||||||
|
if !ok || int(n) != len(questions) { |
||||||
|
return nil, parseError(msgUserAuthInfoResponse) |
||||||
|
} |
||||||
|
|
||||||
|
for i := uint32(0); i < n; i++ { |
||||||
|
ans, rest, ok := parseString(packet) |
||||||
|
if !ok { |
||||||
|
return nil, parseError(msgUserAuthInfoResponse) |
||||||
|
} |
||||||
|
|
||||||
|
answers = append(answers, string(ans)) |
||||||
|
packet = rest |
||||||
|
} |
||||||
|
if len(packet) != 0 { |
||||||
|
return nil, errors.New("ssh: junk at end of message") |
||||||
|
} |
||||||
|
|
||||||
|
return answers, nil |
||||||
|
} |
@ -0,0 +1,605 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
// Session implements an interactive session described in
|
||||||
|
// "RFC 4254, section 6".
|
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
type Signal string |
||||||
|
|
||||||
|
// POSIX signals as listed in RFC 4254 Section 6.10.
|
||||||
|
const ( |
||||||
|
SIGABRT Signal = "ABRT" |
||||||
|
SIGALRM Signal = "ALRM" |
||||||
|
SIGFPE Signal = "FPE" |
||||||
|
SIGHUP Signal = "HUP" |
||||||
|
SIGILL Signal = "ILL" |
||||||
|
SIGINT Signal = "INT" |
||||||
|
SIGKILL Signal = "KILL" |
||||||
|
SIGPIPE Signal = "PIPE" |
||||||
|
SIGQUIT Signal = "QUIT" |
||||||
|
SIGSEGV Signal = "SEGV" |
||||||
|
SIGTERM Signal = "TERM" |
||||||
|
SIGUSR1 Signal = "USR1" |
||||||
|
SIGUSR2 Signal = "USR2" |
||||||
|
) |
||||||
|
|
||||||
|
var signals = map[Signal]int{ |
||||||
|
SIGABRT: 6, |
||||||
|
SIGALRM: 14, |
||||||
|
SIGFPE: 8, |
||||||
|
SIGHUP: 1, |
||||||
|
SIGILL: 4, |
||||||
|
SIGINT: 2, |
||||||
|
SIGKILL: 9, |
||||||
|
SIGPIPE: 13, |
||||||
|
SIGQUIT: 3, |
||||||
|
SIGSEGV: 11, |
||||||
|
SIGTERM: 15, |
||||||
|
} |
||||||
|
|
||||||
|
type TerminalModes map[uint8]uint32 |
||||||
|
|
||||||
|
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
||||||
|
const ( |
||||||
|
tty_OP_END = 0 |
||||||
|
VINTR = 1 |
||||||
|
VQUIT = 2 |
||||||
|
VERASE = 3 |
||||||
|
VKILL = 4 |
||||||
|
VEOF = 5 |
||||||
|
VEOL = 6 |
||||||
|
VEOL2 = 7 |
||||||
|
VSTART = 8 |
||||||
|
VSTOP = 9 |
||||||
|
VSUSP = 10 |
||||||
|
VDSUSP = 11 |
||||||
|
VREPRINT = 12 |
||||||
|
VWERASE = 13 |
||||||
|
VLNEXT = 14 |
||||||
|
VFLUSH = 15 |
||||||
|
VSWTCH = 16 |
||||||
|
VSTATUS = 17 |
||||||
|
VDISCARD = 18 |
||||||
|
IGNPAR = 30 |
||||||
|
PARMRK = 31 |
||||||
|
INPCK = 32 |
||||||
|
ISTRIP = 33 |
||||||
|
INLCR = 34 |
||||||
|
IGNCR = 35 |
||||||
|
ICRNL = 36 |
||||||
|
IUCLC = 37 |
||||||
|
IXON = 38 |
||||||
|
IXANY = 39 |
||||||
|
IXOFF = 40 |
||||||
|
IMAXBEL = 41 |
||||||
|
ISIG = 50 |
||||||
|
ICANON = 51 |
||||||
|
XCASE = 52 |
||||||
|
ECHO = 53 |
||||||
|
ECHOE = 54 |
||||||
|
ECHOK = 55 |
||||||
|
ECHONL = 56 |
||||||
|
NOFLSH = 57 |
||||||
|
TOSTOP = 58 |
||||||
|
IEXTEN = 59 |
||||||
|
ECHOCTL = 60 |
||||||
|
ECHOKE = 61 |
||||||
|
PENDIN = 62 |
||||||
|
OPOST = 70 |
||||||
|
OLCUC = 71 |
||||||
|
ONLCR = 72 |
||||||
|
OCRNL = 73 |
||||||
|
ONOCR = 74 |
||||||
|
ONLRET = 75 |
||||||
|
CS7 = 90 |
||||||
|
CS8 = 91 |
||||||
|
PARENB = 92 |
||||||
|
PARODD = 93 |
||||||
|
TTY_OP_ISPEED = 128 |
||||||
|
TTY_OP_OSPEED = 129 |
||||||
|
) |
||||||
|
|
||||||
|
// A Session represents a connection to a remote command or shell.
|
||||||
|
type Session struct { |
||||||
|
// Stdin specifies the remote process's standard input.
|
||||||
|
// If Stdin is nil, the remote process reads from an empty
|
||||||
|
// bytes.Buffer.
|
||||||
|
Stdin io.Reader |
||||||
|
|
||||||
|
// Stdout and Stderr specify the remote process's standard
|
||||||
|
// output and error.
|
||||||
|
//
|
||||||
|
// If either is nil, Run connects the corresponding file
|
||||||
|
// descriptor to an instance of ioutil.Discard. There is a
|
||||||
|
// fixed amount of buffering that is shared for the two streams.
|
||||||
|
// If either blocks it may eventually cause the remote
|
||||||
|
// command to block.
|
||||||
|
Stdout io.Writer |
||||||
|
Stderr io.Writer |
||||||
|
|
||||||
|
ch Channel // the channel backing this session
|
||||||
|
started bool // true once Start, Run or Shell is invoked.
|
||||||
|
copyFuncs []func() error |
||||||
|
errors chan error // one send per copyFunc
|
||||||
|
|
||||||
|
// true if pipe method is active
|
||||||
|
stdinpipe, stdoutpipe, stderrpipe bool |
||||||
|
|
||||||
|
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
||||||
|
// and Stdin was specified by the user; it is the write end of
|
||||||
|
// a pipe connecting Session.Stdin to the stdin channel.
|
||||||
|
stdinPipeWriter io.WriteCloser |
||||||
|
|
||||||
|
exitStatus chan error |
||||||
|
} |
||||||
|
|
||||||
|
// SendRequest sends an out-of-band channel request on the SSH channel
|
||||||
|
// underlying the session.
|
||||||
|
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||||
|
return s.ch.SendRequest(name, wantReply, payload) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) Close() error { |
||||||
|
return s.ch.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 Section 6.4.
|
||||||
|
type setenvRequest struct { |
||||||
|
Name string |
||||||
|
Value string |
||||||
|
} |
||||||
|
|
||||||
|
// Setenv sets an environment variable that will be applied to any
|
||||||
|
// command executed by Shell or Run.
|
||||||
|
func (s *Session) Setenv(name, value string) error { |
||||||
|
msg := setenvRequest{ |
||||||
|
Name: name, |
||||||
|
Value: value, |
||||||
|
} |
||||||
|
ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) |
||||||
|
if err == nil && !ok { |
||||||
|
err = errors.New("ssh: setenv failed") |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 Section 6.2.
|
||||||
|
type ptyRequestMsg struct { |
||||||
|
Term string |
||||||
|
Columns uint32 |
||||||
|
Rows uint32 |
||||||
|
Width uint32 |
||||||
|
Height uint32 |
||||||
|
Modelist string |
||||||
|
} |
||||||
|
|
||||||
|
// RequestPty requests the association of a pty with the session on the remote host.
|
||||||
|
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { |
||||||
|
var tm []byte |
||||||
|
for k, v := range termmodes { |
||||||
|
kv := struct { |
||||||
|
Key byte |
||||||
|
Val uint32 |
||||||
|
}{k, v} |
||||||
|
|
||||||
|
tm = append(tm, Marshal(&kv)...) |
||||||
|
} |
||||||
|
tm = append(tm, tty_OP_END) |
||||||
|
req := ptyRequestMsg{ |
||||||
|
Term: term, |
||||||
|
Columns: uint32(w), |
||||||
|
Rows: uint32(h), |
||||||
|
Width: uint32(w * 8), |
||||||
|
Height: uint32(h * 8), |
||||||
|
Modelist: string(tm), |
||||||
|
} |
||||||
|
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) |
||||||
|
if err == nil && !ok { |
||||||
|
err = errors.New("ssh: pty-req failed") |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 Section 6.5.
|
||||||
|
type subsystemRequestMsg struct { |
||||||
|
Subsystem string |
||||||
|
} |
||||||
|
|
||||||
|
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
||||||
|
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
||||||
|
func (s *Session) RequestSubsystem(subsystem string) error { |
||||||
|
msg := subsystemRequestMsg{ |
||||||
|
Subsystem: subsystem, |
||||||
|
} |
||||||
|
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) |
||||||
|
if err == nil && !ok { |
||||||
|
err = errors.New("ssh: subsystem request failed") |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 Section 6.9.
|
||||||
|
type signalMsg struct { |
||||||
|
Signal string |
||||||
|
} |
||||||
|
|
||||||
|
// Signal sends the given signal to the remote process.
|
||||||
|
// sig is one of the SIG* constants.
|
||||||
|
func (s *Session) Signal(sig Signal) error { |
||||||
|
msg := signalMsg{ |
||||||
|
Signal: string(sig), |
||||||
|
} |
||||||
|
|
||||||
|
_, err := s.ch.SendRequest("signal", false, Marshal(&msg)) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 Section 6.5.
|
||||||
|
type execMsg struct { |
||||||
|
Command string |
||||||
|
} |
||||||
|
|
||||||
|
// Start runs cmd on the remote host. Typically, the remote
|
||||||
|
// server passes cmd to the shell for interpretation.
|
||||||
|
// A Session only accepts one call to Run, Start or Shell.
|
||||||
|
func (s *Session) Start(cmd string) error { |
||||||
|
if s.started { |
||||||
|
return errors.New("ssh: session already started") |
||||||
|
} |
||||||
|
req := execMsg{ |
||||||
|
Command: cmd, |
||||||
|
} |
||||||
|
|
||||||
|
ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) |
||||||
|
if err == nil && !ok { |
||||||
|
err = fmt.Errorf("ssh: command %v failed", cmd) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return s.start() |
||||||
|
} |
||||||
|
|
||||||
|
// Run runs cmd on the remote host. Typically, the remote
|
||||||
|
// server passes cmd to the shell for interpretation.
|
||||||
|
// A Session only accepts one call to Run, Start, Shell, Output,
|
||||||
|
// or CombinedOutput.
|
||||||
|
//
|
||||||
|
// The returned error is nil if the command runs, has no problems
|
||||||
|
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||||
|
// status.
|
||||||
|
//
|
||||||
|
// If the command fails to run or doesn't complete successfully, the
|
||||||
|
// error is of type *ExitError. Other error types may be
|
||||||
|
// returned for I/O problems.
|
||||||
|
func (s *Session) Run(cmd string) error { |
||||||
|
err := s.Start(cmd) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return s.Wait() |
||||||
|
} |
||||||
|
|
||||||
|
// Output runs cmd on the remote host and returns its standard output.
|
||||||
|
func (s *Session) Output(cmd string) ([]byte, error) { |
||||||
|
if s.Stdout != nil { |
||||||
|
return nil, errors.New("ssh: Stdout already set") |
||||||
|
} |
||||||
|
var b bytes.Buffer |
||||||
|
s.Stdout = &b |
||||||
|
err := s.Run(cmd) |
||||||
|
return b.Bytes(), err |
||||||
|
} |
||||||
|
|
||||||
|
type singleWriter struct { |
||||||
|
b bytes.Buffer |
||||||
|
mu sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
func (w *singleWriter) Write(p []byte) (int, error) { |
||||||
|
w.mu.Lock() |
||||||
|
defer w.mu.Unlock() |
||||||
|
return w.b.Write(p) |
||||||
|
} |
||||||
|
|
||||||
|
// CombinedOutput runs cmd on the remote host and returns its combined
|
||||||
|
// standard output and standard error.
|
||||||
|
func (s *Session) CombinedOutput(cmd string) ([]byte, error) { |
||||||
|
if s.Stdout != nil { |
||||||
|
return nil, errors.New("ssh: Stdout already set") |
||||||
|
} |
||||||
|
if s.Stderr != nil { |
||||||
|
return nil, errors.New("ssh: Stderr already set") |
||||||
|
} |
||||||
|
var b singleWriter |
||||||
|
s.Stdout = &b |
||||||
|
s.Stderr = &b |
||||||
|
err := s.Run(cmd) |
||||||
|
return b.b.Bytes(), err |
||||||
|
} |
||||||
|
|
||||||
|
// Shell starts a login shell on the remote host. A Session only
|
||||||
|
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
||||||
|
func (s *Session) Shell() error { |
||||||
|
if s.started { |
||||||
|
return errors.New("ssh: session already started") |
||||||
|
} |
||||||
|
|
||||||
|
ok, err := s.ch.SendRequest("shell", true, nil) |
||||||
|
if err == nil && !ok { |
||||||
|
return fmt.Errorf("ssh: cound not start shell") |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return s.start() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) start() error { |
||||||
|
s.started = true |
||||||
|
|
||||||
|
type F func(*Session) |
||||||
|
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { |
||||||
|
setupFd(s) |
||||||
|
} |
||||||
|
|
||||||
|
s.errors = make(chan error, len(s.copyFuncs)) |
||||||
|
for _, fn := range s.copyFuncs { |
||||||
|
go func(fn func() error) { |
||||||
|
s.errors <- fn() |
||||||
|
}(fn) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Wait waits for the remote command to exit.
|
||||||
|
//
|
||||||
|
// The returned error is nil if the command runs, has no problems
|
||||||
|
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||||
|
// status.
|
||||||
|
//
|
||||||
|
// If the command fails to run or doesn't complete successfully, the
|
||||||
|
// error is of type *ExitError. Other error types may be
|
||||||
|
// returned for I/O problems.
|
||||||
|
func (s *Session) Wait() error { |
||||||
|
if !s.started { |
||||||
|
return errors.New("ssh: session not started") |
||||||
|
} |
||||||
|
waitErr := <-s.exitStatus |
||||||
|
|
||||||
|
if s.stdinPipeWriter != nil { |
||||||
|
s.stdinPipeWriter.Close() |
||||||
|
} |
||||||
|
var copyError error |
||||||
|
for _ = range s.copyFuncs { |
||||||
|
if err := <-s.errors; err != nil && copyError == nil { |
||||||
|
copyError = err |
||||||
|
} |
||||||
|
} |
||||||
|
if waitErr != nil { |
||||||
|
return waitErr |
||||||
|
} |
||||||
|
return copyError |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) wait(reqs <-chan *Request) error { |
||||||
|
wm := Waitmsg{status: -1} |
||||||
|
// Wait for msg channel to be closed before returning.
|
||||||
|
for msg := range reqs { |
||||||
|
switch msg.Type { |
||||||
|
case "exit-status": |
||||||
|
d := msg.Payload |
||||||
|
wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) |
||||||
|
case "exit-signal": |
||||||
|
var sigval struct { |
||||||
|
Signal string |
||||||
|
CoreDumped bool |
||||||
|
Error string |
||||||
|
Lang string |
||||||
|
} |
||||||
|
if err := Unmarshal(msg.Payload, &sigval); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// Must sanitize strings?
|
||||||
|
wm.signal = sigval.Signal |
||||||
|
wm.msg = sigval.Error |
||||||
|
wm.lang = sigval.Lang |
||||||
|
default: |
||||||
|
// This handles keepalives and matches
|
||||||
|
// OpenSSH's behaviour.
|
||||||
|
if msg.WantReply { |
||||||
|
msg.Reply(false, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
if wm.status == 0 { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if wm.status == -1 { |
||||||
|
// exit-status was never sent from server
|
||||||
|
if wm.signal == "" { |
||||||
|
return errors.New("wait: remote command exited without exit status or exit signal") |
||||||
|
} |
||||||
|
wm.status = 128 |
||||||
|
if _, ok := signals[Signal(wm.signal)]; ok { |
||||||
|
wm.status += signals[Signal(wm.signal)] |
||||||
|
} |
||||||
|
} |
||||||
|
return &ExitError{wm} |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) stdin() { |
||||||
|
if s.stdinpipe { |
||||||
|
return |
||||||
|
} |
||||||
|
var stdin io.Reader |
||||||
|
if s.Stdin == nil { |
||||||
|
stdin = new(bytes.Buffer) |
||||||
|
} else { |
||||||
|
r, w := io.Pipe() |
||||||
|
go func() { |
||||||
|
_, err := io.Copy(w, s.Stdin) |
||||||
|
w.CloseWithError(err) |
||||||
|
}() |
||||||
|
stdin, s.stdinPipeWriter = r, w |
||||||
|
} |
||||||
|
s.copyFuncs = append(s.copyFuncs, func() error { |
||||||
|
_, err := io.Copy(s.ch, stdin) |
||||||
|
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { |
||||||
|
err = err1 |
||||||
|
} |
||||||
|
return err |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) stdout() { |
||||||
|
if s.stdoutpipe { |
||||||
|
return |
||||||
|
} |
||||||
|
if s.Stdout == nil { |
||||||
|
s.Stdout = ioutil.Discard |
||||||
|
} |
||||||
|
s.copyFuncs = append(s.copyFuncs, func() error { |
||||||
|
_, err := io.Copy(s.Stdout, s.ch) |
||||||
|
return err |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Session) stderr() { |
||||||
|
if s.stderrpipe { |
||||||
|
return |
||||||
|
} |
||||||
|
if s.Stderr == nil { |
||||||
|
s.Stderr = ioutil.Discard |
||||||
|
} |
||||||
|
s.copyFuncs = append(s.copyFuncs, func() error { |
||||||
|
_, err := io.Copy(s.Stderr, s.ch.Stderr()) |
||||||
|
return err |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// sessionStdin reroutes Close to CloseWrite.
|
||||||
|
type sessionStdin struct { |
||||||
|
io.Writer |
||||||
|
ch Channel |
||||||
|
} |
||||||
|
|
||||||
|
func (s *sessionStdin) Close() error { |
||||||
|
return s.ch.CloseWrite() |
||||||
|
} |
||||||
|
|
||||||
|
// StdinPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard input when the command starts.
|
||||||
|
func (s *Session) StdinPipe() (io.WriteCloser, error) { |
||||||
|
if s.Stdin != nil { |
||||||
|
return nil, errors.New("ssh: Stdin already set") |
||||||
|
} |
||||||
|
if s.started { |
||||||
|
return nil, errors.New("ssh: StdinPipe after process started") |
||||||
|
} |
||||||
|
s.stdinpipe = true |
||||||
|
return &sessionStdin{s.ch, s.ch}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// StdoutPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard output when the command starts.
|
||||||
|
// There is a fixed amount of buffering that is shared between
|
||||||
|
// stdout and stderr streams. If the StdoutPipe reader is
|
||||||
|
// not serviced fast enough it may eventually cause the
|
||||||
|
// remote command to block.
|
||||||
|
func (s *Session) StdoutPipe() (io.Reader, error) { |
||||||
|
if s.Stdout != nil { |
||||||
|
return nil, errors.New("ssh: Stdout already set") |
||||||
|
} |
||||||
|
if s.started { |
||||||
|
return nil, errors.New("ssh: StdoutPipe after process started") |
||||||
|
} |
||||||
|
s.stdoutpipe = true |
||||||
|
return s.ch, nil |
||||||
|
} |
||||||
|
|
||||||
|
// StderrPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard error when the command starts.
|
||||||
|
// There is a fixed amount of buffering that is shared between
|
||||||
|
// stdout and stderr streams. If the StderrPipe reader is
|
||||||
|
// not serviced fast enough it may eventually cause the
|
||||||
|
// remote command to block.
|
||||||
|
func (s *Session) StderrPipe() (io.Reader, error) { |
||||||
|
if s.Stderr != nil { |
||||||
|
return nil, errors.New("ssh: Stderr already set") |
||||||
|
} |
||||||
|
if s.started { |
||||||
|
return nil, errors.New("ssh: StderrPipe after process started") |
||||||
|
} |
||||||
|
s.stderrpipe = true |
||||||
|
return s.ch.Stderr(), nil |
||||||
|
} |
||||||
|
|
||||||
|
// newSession returns a new interactive session on the remote host.
|
||||||
|
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { |
||||||
|
s := &Session{ |
||||||
|
ch: ch, |
||||||
|
} |
||||||
|
s.exitStatus = make(chan error, 1) |
||||||
|
go func() { |
||||||
|
s.exitStatus <- s.wait(reqs) |
||||||
|
}() |
||||||
|
|
||||||
|
return s, nil |
||||||
|
} |
||||||
|
|
||||||
|
// An ExitError reports unsuccessful completion of a remote command.
|
||||||
|
type ExitError struct { |
||||||
|
Waitmsg |
||||||
|
} |
||||||
|
|
||||||
|
func (e *ExitError) Error() string { |
||||||
|
return e.Waitmsg.String() |
||||||
|
} |
||||||
|
|
||||||
|
// Waitmsg stores the information about an exited remote command
|
||||||
|
// as reported by Wait.
|
||||||
|
type Waitmsg struct { |
||||||
|
status int |
||||||
|
signal string |
||||||
|
msg string |
||||||
|
lang string |
||||||
|
} |
||||||
|
|
||||||
|
// ExitStatus returns the exit status of the remote command.
|
||||||
|
func (w Waitmsg) ExitStatus() int { |
||||||
|
return w.status |
||||||
|
} |
||||||
|
|
||||||
|
// Signal returns the exit signal of the remote command if
|
||||||
|
// it was terminated violently.
|
||||||
|
func (w Waitmsg) Signal() string { |
||||||
|
return w.signal |
||||||
|
} |
||||||
|
|
||||||
|
// Msg returns the exit message given by the remote command
|
||||||
|
func (w Waitmsg) Msg() string { |
||||||
|
return w.msg |
||||||
|
} |
||||||
|
|
||||||
|
// Lang returns the language tag. See RFC 3066
|
||||||
|
func (w Waitmsg) Lang() string { |
||||||
|
return w.lang |
||||||
|
} |
||||||
|
|
||||||
|
func (w Waitmsg) String() string { |
||||||
|
return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal) |
||||||
|
} |
@ -0,0 +1,774 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
// Session tests.
|
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
crypto_rand "crypto/rand" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
"math/rand" |
||||||
|
"net" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/ssh/terminal" |
||||||
|
) |
||||||
|
|
||||||
|
type serverType func(Channel, <-chan *Request, *testing.T) |
||||||
|
|
||||||
|
// dial constructs a new test server and returns a *ClientConn.
|
||||||
|
func dial(handler serverType, t *testing.T) *Client { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
defer c1.Close() |
||||||
|
conf := ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
conf.AddHostKey(testSigners["rsa"]) |
||||||
|
|
||||||
|
_, chans, reqs, err := NewServerConn(c1, &conf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to handshake: %v", err) |
||||||
|
} |
||||||
|
go DiscardRequests(reqs) |
||||||
|
|
||||||
|
for newCh := range chans { |
||||||
|
if newCh.ChannelType() != "session" { |
||||||
|
newCh.Reject(UnknownChannelType, "unknown channel type") |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
ch, inReqs, err := newCh.Accept() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("Accept: %v", err) |
||||||
|
continue |
||||||
|
} |
||||||
|
go func() { |
||||||
|
handler(ch, inReqs, t) |
||||||
|
}() |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
config := &ClientConfig{ |
||||||
|
User: "testuser", |
||||||
|
} |
||||||
|
|
||||||
|
conn, chans, reqs, err := NewClientConn(c2, "", config) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to dial remote side: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
return NewClient(conn, chans, reqs) |
||||||
|
} |
||||||
|
|
||||||
|
// Test a simple string is returned to session.Stdout.
|
||||||
|
func TestSessionShell(t *testing.T) { |
||||||
|
conn := dial(shellHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
stdout := new(bytes.Buffer) |
||||||
|
session.Stdout = stdout |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %s", err) |
||||||
|
} |
||||||
|
if err := session.Wait(); err != nil { |
||||||
|
t.Fatalf("Remote command did not exit cleanly: %v", err) |
||||||
|
} |
||||||
|
actual := stdout.String() |
||||||
|
if actual != "golang" { |
||||||
|
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
|
||||||
|
|
||||||
|
// Test a simple string is returned via StdoutPipe.
|
||||||
|
func TestSessionStdoutPipe(t *testing.T) { |
||||||
|
conn := dial(shellHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
stdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request StdoutPipe(): %v", err) |
||||||
|
} |
||||||
|
var buf bytes.Buffer |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
done := make(chan bool, 1) |
||||||
|
go func() { |
||||||
|
if _, err := io.Copy(&buf, stdout); err != nil { |
||||||
|
t.Errorf("Copy of stdout failed: %v", err) |
||||||
|
} |
||||||
|
done <- true |
||||||
|
}() |
||||||
|
if err := session.Wait(); err != nil { |
||||||
|
t.Fatalf("Remote command did not exit cleanly: %v", err) |
||||||
|
} |
||||||
|
<-done |
||||||
|
actual := buf.String() |
||||||
|
if actual != "golang" { |
||||||
|
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test that a simple string is returned via the Output helper,
|
||||||
|
// and that stderr is discarded.
|
||||||
|
func TestSessionOutput(t *testing.T) { |
||||||
|
conn := dial(fixedOutputHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
|
||||||
|
if err != nil { |
||||||
|
t.Error("Remote command did not exit cleanly:", err) |
||||||
|
} |
||||||
|
w := "this-is-stdout." |
||||||
|
g := string(buf) |
||||||
|
if g != w { |
||||||
|
t.Error("Remote command did not return expected string:") |
||||||
|
t.Logf("want %q", w) |
||||||
|
t.Logf("got %q", g) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test that both stdout and stderr are returned
|
||||||
|
// via the CombinedOutput helper.
|
||||||
|
func TestSessionCombinedOutput(t *testing.T) { |
||||||
|
conn := dial(fixedOutputHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
|
||||||
|
if err != nil { |
||||||
|
t.Error("Remote command did not exit cleanly:", err) |
||||||
|
} |
||||||
|
const stdout = "this-is-stdout." |
||||||
|
const stderr = "this-is-stderr." |
||||||
|
g := string(buf) |
||||||
|
if g != stdout+stderr && g != stderr+stdout { |
||||||
|
t.Error("Remote command did not return expected string:") |
||||||
|
t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) |
||||||
|
t.Logf("got %q", g) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test non-0 exit status is returned correctly.
|
||||||
|
func TestExitStatusNonZero(t *testing.T) { |
||||||
|
conn := dial(exitStatusNonZeroHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("expected command to fail but it didn't") |
||||||
|
} |
||||||
|
e, ok := err.(*ExitError) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("expected *ExitError but got %T", err) |
||||||
|
} |
||||||
|
if e.ExitStatus() != 15 { |
||||||
|
t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test 0 exit status is returned correctly.
|
||||||
|
func TestExitStatusZero(t *testing.T) { |
||||||
|
conn := dial(exitStatusZeroHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("expected nil but got %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test exit signal and status are both returned correctly.
|
||||||
|
func TestExitSignalAndStatus(t *testing.T) { |
||||||
|
conn := dial(exitSignalAndStatusHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("expected command to fail but it didn't") |
||||||
|
} |
||||||
|
e, ok := err.(*ExitError) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("expected *ExitError but got %T", err) |
||||||
|
} |
||||||
|
if e.Signal() != "TERM" || e.ExitStatus() != 15 { |
||||||
|
t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test exit signal and status are both returned correctly.
|
||||||
|
func TestKnownExitSignalOnly(t *testing.T) { |
||||||
|
conn := dial(exitSignalHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("expected command to fail but it didn't") |
||||||
|
} |
||||||
|
e, ok := err.(*ExitError) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("expected *ExitError but got %T", err) |
||||||
|
} |
||||||
|
if e.Signal() != "TERM" || e.ExitStatus() != 143 { |
||||||
|
t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test exit signal and status are both returned correctly.
|
||||||
|
func TestUnknownExitSignal(t *testing.T) { |
||||||
|
conn := dial(exitSignalUnknownHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("expected command to fail but it didn't") |
||||||
|
} |
||||||
|
e, ok := err.(*ExitError) |
||||||
|
if !ok { |
||||||
|
t.Fatalf("expected *ExitError but got %T", err) |
||||||
|
} |
||||||
|
if e.Signal() != "SYS" || e.ExitStatus() != 128 { |
||||||
|
t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Test WaitMsg is not returned if the channel closes abruptly.
|
||||||
|
func TestExitWithoutStatusOrSignal(t *testing.T) { |
||||||
|
conn := dial(exitWithoutSignalOrStatus, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Unable to request new session: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("expected command to fail but it didn't") |
||||||
|
} |
||||||
|
_, ok := err.(*ExitError) |
||||||
|
if ok { |
||||||
|
// you can't actually test for errors.errorString
|
||||||
|
// because it's not exported.
|
||||||
|
t.Fatalf("expected *errorString but got %T", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// windowTestBytes is the number of bytes that we'll send to the SSH server.
|
||||||
|
const windowTestBytes = 16000 * 200 |
||||||
|
|
||||||
|
// TestServerWindow writes random data to the server. The server is expected to echo
|
||||||
|
// the same data back, which is compared against the original.
|
||||||
|
func TestServerWindow(t *testing.T) { |
||||||
|
origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
||||||
|
io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) |
||||||
|
origBytes := origBuf.Bytes() |
||||||
|
|
||||||
|
conn := dial(echoHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
result := make(chan []byte) |
||||||
|
|
||||||
|
go func() { |
||||||
|
defer close(result) |
||||||
|
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) |
||||||
|
serverStdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("StdoutPipe failed: %v", err) |
||||||
|
return |
||||||
|
} |
||||||
|
n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) |
||||||
|
if err != nil && err != io.EOF { |
||||||
|
t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) |
||||||
|
} |
||||||
|
result <- echoedBuf.Bytes() |
||||||
|
}() |
||||||
|
|
||||||
|
serverStdin, err := session.StdinPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("StdinPipe failed: %v", err) |
||||||
|
} |
||||||
|
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("failed to copy origBuf to serverStdin: %v", err) |
||||||
|
} |
||||||
|
if written != windowTestBytes { |
||||||
|
t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) |
||||||
|
} |
||||||
|
|
||||||
|
echoedBytes := <-result |
||||||
|
|
||||||
|
if !bytes.Equal(origBytes, echoedBytes) { |
||||||
|
t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Verify the client can handle a keepalive packet from the server.
|
||||||
|
func TestClientHandlesKeepalives(t *testing.T) { |
||||||
|
conn := dial(channelKeepaliveSender, t) |
||||||
|
defer conn.Close() |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
if err := session.Shell(); err != nil { |
||||||
|
t.Fatalf("Unable to execute command: %v", err) |
||||||
|
} |
||||||
|
err = session.Wait() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("expected nil but got: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type exitStatusMsg struct { |
||||||
|
Status uint32 |
||||||
|
} |
||||||
|
|
||||||
|
type exitSignalMsg struct { |
||||||
|
Signal string |
||||||
|
CoreDumped bool |
||||||
|
Errmsg string |
||||||
|
Lang string |
||||||
|
} |
||||||
|
|
||||||
|
func handleTerminalRequests(in <-chan *Request) { |
||||||
|
for req := range in { |
||||||
|
ok := false |
||||||
|
switch req.Type { |
||||||
|
case "shell": |
||||||
|
ok = true |
||||||
|
if len(req.Payload) > 0 { |
||||||
|
// We don't accept any commands, only the default shell.
|
||||||
|
ok = false |
||||||
|
} |
||||||
|
case "env": |
||||||
|
ok = true |
||||||
|
} |
||||||
|
req.Reply(ok, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { |
||||||
|
term := terminal.NewTerminal(ch, prompt) |
||||||
|
go handleTerminalRequests(in) |
||||||
|
return term |
||||||
|
} |
||||||
|
|
||||||
|
func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
// this string is returned to stdout
|
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
sendStatus(0, ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
sendStatus(15, ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
sendStatus(15, ch, t) |
||||||
|
sendSignal("TERM", ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
sendSignal("TERM", ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
sendSignal("SYS", ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
} |
||||||
|
|
||||||
|
func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
// this string is returned to stdout
|
||||||
|
shell := newServerShell(ch, in, "golang") |
||||||
|
readLine(shell, t) |
||||||
|
sendStatus(0, ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
// Ignores the command, writes fixed strings to stderr and stdout.
|
||||||
|
// Strings are "this-is-stdout." and "this-is-stderr.".
|
||||||
|
func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
_, err := ch.Read(nil) |
||||||
|
|
||||||
|
req, ok := <-in |
||||||
|
if !ok { |
||||||
|
t.Fatalf("error: expected channel request, got: %#v", err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// ignore request, always send some text
|
||||||
|
req.Reply(true, nil) |
||||||
|
|
||||||
|
_, err = io.WriteString(ch, "this-is-stdout.") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("error writing on server: %v", err) |
||||||
|
} |
||||||
|
_, err = io.WriteString(ch.Stderr(), "this-is-stderr.") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("error writing on server: %v", err) |
||||||
|
} |
||||||
|
sendStatus(0, ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func readLine(shell *terminal.Terminal, t *testing.T) { |
||||||
|
if _, err := shell.ReadLine(); err != nil && err != io.EOF { |
||||||
|
t.Errorf("unable to read line: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func sendStatus(status uint32, ch Channel, t *testing.T) { |
||||||
|
msg := exitStatusMsg{ |
||||||
|
Status: status, |
||||||
|
} |
||||||
|
if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { |
||||||
|
t.Errorf("unable to send status: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func sendSignal(signal string, ch Channel, t *testing.T) { |
||||||
|
sig := exitSignalMsg{ |
||||||
|
Signal: signal, |
||||||
|
CoreDumped: false, |
||||||
|
Errmsg: "Process terminated", |
||||||
|
Lang: "en-GB-oed", |
||||||
|
} |
||||||
|
if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { |
||||||
|
t.Errorf("unable to send signal: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func discardHandler(ch Channel, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
io.Copy(ioutil.Discard, ch) |
||||||
|
} |
||||||
|
|
||||||
|
func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { |
||||||
|
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
|
||||||
|
// buffer size to exercise more code paths.
|
||||||
|
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { |
||||||
|
var ( |
||||||
|
buf = make([]byte, 32*1024) |
||||||
|
written int |
||||||
|
remaining = n |
||||||
|
) |
||||||
|
for remaining > 0 { |
||||||
|
l := rand.Intn(1 << 15) |
||||||
|
if remaining < l { |
||||||
|
l = remaining |
||||||
|
} |
||||||
|
nr, er := src.Read(buf[:l]) |
||||||
|
nw, ew := dst.Write(buf[:nr]) |
||||||
|
remaining -= nw |
||||||
|
written += nw |
||||||
|
if ew != nil { |
||||||
|
return written, ew |
||||||
|
} |
||||||
|
if nr != nw { |
||||||
|
return written, io.ErrShortWrite |
||||||
|
} |
||||||
|
if er != nil && er != io.EOF { |
||||||
|
return written, er |
||||||
|
} |
||||||
|
} |
||||||
|
return written, nil |
||||||
|
} |
||||||
|
|
||||||
|
func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
shell := newServerShell(ch, in, "> ") |
||||||
|
readLine(shell, t) |
||||||
|
if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { |
||||||
|
t.Errorf("unable to send channel keepalive request: %v", err) |
||||||
|
} |
||||||
|
sendStatus(0, ch, t) |
||||||
|
} |
||||||
|
|
||||||
|
func TestClientWriteEOF(t *testing.T) { |
||||||
|
conn := dial(simpleEchoHandler, t) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
stdin, err := session.StdinPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("StdinPipe failed: %v", err) |
||||||
|
} |
||||||
|
stdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("StdoutPipe failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
data := []byte(`0000`) |
||||||
|
_, err = stdin.Write(data) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Write failed: %v", err) |
||||||
|
} |
||||||
|
stdin.Close() |
||||||
|
|
||||||
|
res, err := ioutil.ReadAll(stdout) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Read failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if !bytes.Equal(data, res) { |
||||||
|
t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { |
||||||
|
defer ch.Close() |
||||||
|
data, err := ioutil.ReadAll(ch) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("handler read error: %v", err) |
||||||
|
} |
||||||
|
_, err = ch.Write(data) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("handler write error: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestSessionID(t *testing.T) { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
serverID := make(chan []byte, 1) |
||||||
|
clientID := make(chan []byte, 1) |
||||||
|
|
||||||
|
serverConf := &ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||||
|
clientConf := &ClientConfig{ |
||||||
|
User: "user", |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
conn, chans, reqs, err := NewServerConn(c1, serverConf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("server handshake: %v", err) |
||||||
|
} |
||||||
|
serverID <- conn.SessionID() |
||||||
|
go DiscardRequests(reqs) |
||||||
|
for ch := range chans { |
||||||
|
ch.Reject(Prohibited, "") |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
go func() { |
||||||
|
conn, chans, reqs, err := NewClientConn(c2, "", clientConf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("client handshake: %v", err) |
||||||
|
} |
||||||
|
clientID <- conn.SessionID() |
||||||
|
go DiscardRequests(reqs) |
||||||
|
for ch := range chans { |
||||||
|
ch.Reject(Prohibited, "") |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
s := <-serverID |
||||||
|
c := <-clientID |
||||||
|
if bytes.Compare(s, c) != 0 { |
||||||
|
t.Errorf("server session ID (%x) != client session ID (%x)", s, c) |
||||||
|
} else if len(s) == 0 { |
||||||
|
t.Errorf("client and server SessionID were empty.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type noReadConn struct { |
||||||
|
readSeen bool |
||||||
|
net.Conn |
||||||
|
} |
||||||
|
|
||||||
|
func (c *noReadConn) Close() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *noReadConn) Read(b []byte) (int, error) { |
||||||
|
c.readSeen = true |
||||||
|
return 0, errors.New("noReadConn error") |
||||||
|
} |
||||||
|
|
||||||
|
func TestInvalidServerConfiguration(t *testing.T) { |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
serveConn := noReadConn{Conn: c1} |
||||||
|
serverConf := &ServerConfig{} |
||||||
|
|
||||||
|
NewServerConn(&serveConn, serverConf) |
||||||
|
if serveConn.readSeen { |
||||||
|
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") |
||||||
|
} |
||||||
|
|
||||||
|
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||||
|
|
||||||
|
NewServerConn(&serveConn, serverConf) |
||||||
|
if serveConn.readSeen { |
||||||
|
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHostKeyAlgorithms(t *testing.T) { |
||||||
|
serverConf := &ServerConfig{ |
||||||
|
NoClientAuth: true, |
||||||
|
} |
||||||
|
serverConf.AddHostKey(testSigners["rsa"]) |
||||||
|
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||||
|
|
||||||
|
connect := func(clientConf *ClientConfig, want string) { |
||||||
|
var alg string |
||||||
|
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { |
||||||
|
alg = key.Type() |
||||||
|
return nil |
||||||
|
} |
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
go NewServerConn(c1, serverConf) |
||||||
|
_, _, _, err = NewClientConn(c2, "", clientConf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewClientConn: %v", err) |
||||||
|
} |
||||||
|
if alg != want { |
||||||
|
t.Errorf("selected key algorithm %s, want %s", alg, want) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// By default, we get the preferred algorithm, which is ECDSA 256.
|
||||||
|
|
||||||
|
clientConf := &ClientConfig{} |
||||||
|
connect(clientConf, KeyAlgoECDSA256) |
||||||
|
|
||||||
|
// Client asks for RSA explicitly.
|
||||||
|
clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} |
||||||
|
connect(clientConf, KeyAlgoRSA) |
||||||
|
|
||||||
|
c1, c2, err := netPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("netPipe: %v", err) |
||||||
|
} |
||||||
|
defer c1.Close() |
||||||
|
defer c2.Close() |
||||||
|
|
||||||
|
go NewServerConn(c1, serverConf) |
||||||
|
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} |
||||||
|
_, _, _, err = NewClientConn(c2, "", clientConf) |
||||||
|
if err == nil { |
||||||
|
t.Fatal("succeeded connecting with unknown hostkey algorithm") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,407 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"math/rand" |
||||||
|
"net" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
// Listen requests the remote peer open a listening socket on
|
||||||
|
// addr. Incoming connections will be available by calling Accept on
|
||||||
|
// the returned net.Listener. The listener must be serviced, or the
|
||||||
|
// SSH connection may hang.
|
||||||
|
func (c *Client) Listen(n, addr string) (net.Listener, error) { |
||||||
|
laddr, err := net.ResolveTCPAddr(n, addr) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return c.ListenTCP(laddr) |
||||||
|
} |
||||||
|
|
||||||
|
// Automatic port allocation is broken with OpenSSH before 6.0. See
|
||||||
|
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
|
||||||
|
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
|
||||||
|
// rather than the actual port number. This means you can never open
|
||||||
|
// two different listeners with auto allocated ports. We work around
|
||||||
|
// this by trying explicit ports until we succeed.
|
||||||
|
|
||||||
|
const openSSHPrefix = "OpenSSH_" |
||||||
|
|
||||||
|
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) |
||||||
|
|
||||||
|
// isBrokenOpenSSHVersion returns true if the given version string
|
||||||
|
// specifies a version of OpenSSH that is known to have a bug in port
|
||||||
|
// forwarding.
|
||||||
|
func isBrokenOpenSSHVersion(versionStr string) bool { |
||||||
|
i := strings.Index(versionStr, openSSHPrefix) |
||||||
|
if i < 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
i += len(openSSHPrefix) |
||||||
|
j := i |
||||||
|
for ; j < len(versionStr); j++ { |
||||||
|
if versionStr[j] < '0' || versionStr[j] > '9' { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
version, _ := strconv.Atoi(versionStr[i:j]) |
||||||
|
return version < 6 |
||||||
|
} |
||||||
|
|
||||||
|
// autoPortListenWorkaround simulates automatic port allocation by
|
||||||
|
// trying random ports repeatedly.
|
||||||
|
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { |
||||||
|
var sshListener net.Listener |
||||||
|
var err error |
||||||
|
const tries = 10 |
||||||
|
for i := 0; i < tries; i++ { |
||||||
|
addr := *laddr |
||||||
|
addr.Port = 1024 + portRandomizer.Intn(60000) |
||||||
|
sshListener, err = c.ListenTCP(&addr) |
||||||
|
if err == nil { |
||||||
|
laddr.Port = addr.Port |
||||||
|
return sshListener, err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 7.1
|
||||||
|
type channelForwardMsg struct { |
||||||
|
addr string |
||||||
|
rport uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// ListenTCP requests the remote peer open a listening socket
|
||||||
|
// on laddr. Incoming connections will be available by calling
|
||||||
|
// Accept on the returned net.Listener.
|
||||||
|
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { |
||||||
|
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { |
||||||
|
return c.autoPortListenWorkaround(laddr) |
||||||
|
} |
||||||
|
|
||||||
|
m := channelForwardMsg{ |
||||||
|
laddr.IP.String(), |
||||||
|
uint32(laddr.Port), |
||||||
|
} |
||||||
|
// send message
|
||||||
|
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if !ok { |
||||||
|
return nil, errors.New("ssh: tcpip-forward request denied by peer") |
||||||
|
} |
||||||
|
|
||||||
|
// If the original port was 0, then the remote side will
|
||||||
|
// supply a real port number in the response.
|
||||||
|
if laddr.Port == 0 { |
||||||
|
var p struct { |
||||||
|
Port uint32 |
||||||
|
} |
||||||
|
if err := Unmarshal(resp, &p); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
laddr.Port = int(p.Port) |
||||||
|
} |
||||||
|
|
||||||
|
// Register this forward, using the port number we obtained.
|
||||||
|
ch := c.forwards.add(*laddr) |
||||||
|
|
||||||
|
return &tcpListener{laddr, c, ch}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// forwardList stores a mapping between remote
|
||||||
|
// forward requests and the tcpListeners.
|
||||||
|
type forwardList struct { |
||||||
|
sync.Mutex |
||||||
|
entries []forwardEntry |
||||||
|
} |
||||||
|
|
||||||
|
// forwardEntry represents an established mapping of a laddr on a
|
||||||
|
// remote ssh server to a channel connected to a tcpListener.
|
||||||
|
type forwardEntry struct { |
||||||
|
laddr net.TCPAddr |
||||||
|
c chan forward |
||||||
|
} |
||||||
|
|
||||||
|
// forward represents an incoming forwarded tcpip connection. The
|
||||||
|
// arguments to add/remove/lookup should be address as specified in
|
||||||
|
// the original forward-request.
|
||||||
|
type forward struct { |
||||||
|
newCh NewChannel // the ssh client channel underlying this forward
|
||||||
|
raddr *net.TCPAddr // the raddr of the incoming connection
|
||||||
|
} |
||||||
|
|
||||||
|
func (l *forwardList) add(addr net.TCPAddr) chan forward { |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
f := forwardEntry{ |
||||||
|
addr, |
||||||
|
make(chan forward, 1), |
||||||
|
} |
||||||
|
l.entries = append(l.entries, f) |
||||||
|
return f.c |
||||||
|
} |
||||||
|
|
||||||
|
// See RFC 4254, section 7.2
|
||||||
|
type forwardedTCPPayload struct { |
||||||
|
Addr string |
||||||
|
Port uint32 |
||||||
|
OriginAddr string |
||||||
|
OriginPort uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
|
||||||
|
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { |
||||||
|
if port == 0 || port > 65535 { |
||||||
|
return nil, fmt.Errorf("ssh: port number out of range: %d", port) |
||||||
|
} |
||||||
|
ip := net.ParseIP(string(addr)) |
||||||
|
if ip == nil { |
||||||
|
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) |
||||||
|
} |
||||||
|
return &net.TCPAddr{IP: ip, Port: int(port)}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (l *forwardList) handleChannels(in <-chan NewChannel) { |
||||||
|
for ch := range in { |
||||||
|
var payload forwardedTCPPayload |
||||||
|
if err := Unmarshal(ch.ExtraData(), &payload); err != nil { |
||||||
|
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 section 7.2 specifies that incoming
|
||||||
|
// addresses should list the address, in string
|
||||||
|
// format. It is implied that this should be an IP
|
||||||
|
// address, as it would be impossible to connect to it
|
||||||
|
// otherwise.
|
||||||
|
laddr, err := parseTCPAddr(payload.Addr, payload.Port) |
||||||
|
if err != nil { |
||||||
|
ch.Reject(ConnectionFailed, err.Error()) |
||||||
|
continue |
||||||
|
} |
||||||
|
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) |
||||||
|
if err != nil { |
||||||
|
ch.Reject(ConnectionFailed, err.Error()) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if ok := l.forward(*laddr, *raddr, ch); !ok { |
||||||
|
// Section 7.2, implementations MUST reject spurious incoming
|
||||||
|
// connections.
|
||||||
|
ch.Reject(Prohibited, "no forward for address") |
||||||
|
continue |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// remove removes the forward entry, and the channel feeding its
|
||||||
|
// listener.
|
||||||
|
func (l *forwardList) remove(addr net.TCPAddr) { |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
for i, f := range l.entries { |
||||||
|
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { |
||||||
|
l.entries = append(l.entries[:i], l.entries[i+1:]...) |
||||||
|
close(f.c) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// closeAll closes and clears all forwards.
|
||||||
|
func (l *forwardList) closeAll() { |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
for _, f := range l.entries { |
||||||
|
close(f.c) |
||||||
|
} |
||||||
|
l.entries = nil |
||||||
|
} |
||||||
|
|
||||||
|
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
for _, f := range l.entries { |
||||||
|
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { |
||||||
|
f.c <- forward{ch, &raddr} |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
type tcpListener struct { |
||||||
|
laddr *net.TCPAddr |
||||||
|
|
||||||
|
conn *Client |
||||||
|
in <-chan forward |
||||||
|
} |
||||||
|
|
||||||
|
// Accept waits for and returns the next connection to the listener.
|
||||||
|
func (l *tcpListener) Accept() (net.Conn, error) { |
||||||
|
s, ok := <-l.in |
||||||
|
if !ok { |
||||||
|
return nil, io.EOF |
||||||
|
} |
||||||
|
ch, incoming, err := s.newCh.Accept() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
go DiscardRequests(incoming) |
||||||
|
|
||||||
|
return &tcpChanConn{ |
||||||
|
Channel: ch, |
||||||
|
laddr: l.laddr, |
||||||
|
raddr: s.raddr, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Close closes the listener.
|
||||||
|
func (l *tcpListener) Close() error { |
||||||
|
m := channelForwardMsg{ |
||||||
|
l.laddr.IP.String(), |
||||||
|
uint32(l.laddr.Port), |
||||||
|
} |
||||||
|
|
||||||
|
// this also closes the listener.
|
||||||
|
l.conn.forwards.remove(*l.laddr) |
||||||
|
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) |
||||||
|
if err == nil && !ok { |
||||||
|
err = errors.New("ssh: cancel-tcpip-forward failed") |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// Addr returns the listener's network address.
|
||||||
|
func (l *tcpListener) Addr() net.Addr { |
||||||
|
return l.laddr |
||||||
|
} |
||||||
|
|
||||||
|
// Dial initiates a connection to the addr from the remote host.
|
||||||
|
// The resulting connection has a zero LocalAddr() and RemoteAddr().
|
||||||
|
func (c *Client) Dial(n, addr string) (net.Conn, error) { |
||||||
|
// Parse the address into host and numeric port.
|
||||||
|
host, portString, err := net.SplitHostPort(addr) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
port, err := strconv.ParseUint(portString, 10, 16) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
// Use a zero address for local and remote address.
|
||||||
|
zeroAddr := &net.TCPAddr{ |
||||||
|
IP: net.IPv4zero, |
||||||
|
Port: 0, |
||||||
|
} |
||||||
|
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &tcpChanConn{ |
||||||
|
Channel: ch, |
||||||
|
laddr: zeroAddr, |
||||||
|
raddr: zeroAddr, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// DialTCP connects to the remote address raddr on the network net,
|
||||||
|
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
||||||
|
// as the local address for the connection.
|
||||||
|
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { |
||||||
|
if laddr == nil { |
||||||
|
laddr = &net.TCPAddr{ |
||||||
|
IP: net.IPv4zero, |
||||||
|
Port: 0, |
||||||
|
} |
||||||
|
} |
||||||
|
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &tcpChanConn{ |
||||||
|
Channel: ch, |
||||||
|
laddr: laddr, |
||||||
|
raddr: raddr, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// RFC 4254 7.2
|
||||||
|
type channelOpenDirectMsg struct { |
||||||
|
raddr string |
||||||
|
rport uint32 |
||||||
|
laddr string |
||||||
|
lport uint32 |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { |
||||||
|
msg := channelOpenDirectMsg{ |
||||||
|
raddr: raddr, |
||||||
|
rport: uint32(rport), |
||||||
|
laddr: laddr, |
||||||
|
lport: uint32(lport), |
||||||
|
} |
||||||
|
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
go DiscardRequests(in) |
||||||
|
return ch, err |
||||||
|
} |
||||||
|
|
||||||
|
type tcpChan struct { |
||||||
|
Channel // the backing channel
|
||||||
|
} |
||||||
|
|
||||||
|
// tcpChanConn fulfills the net.Conn interface without
|
||||||
|
// the tcpChan having to hold laddr or raddr directly.
|
||||||
|
type tcpChanConn struct { |
||||||
|
Channel |
||||||
|
laddr, raddr net.Addr |
||||||
|
} |
||||||
|
|
||||||
|
// LocalAddr returns the local network address.
|
||||||
|
func (t *tcpChanConn) LocalAddr() net.Addr { |
||||||
|
return t.laddr |
||||||
|
} |
||||||
|
|
||||||
|
// RemoteAddr returns the remote network address.
|
||||||
|
func (t *tcpChanConn) RemoteAddr() net.Addr { |
||||||
|
return t.raddr |
||||||
|
} |
||||||
|
|
||||||
|
// SetDeadline sets the read and write deadlines associated
|
||||||
|
// with the connection.
|
||||||
|
func (t *tcpChanConn) SetDeadline(deadline time.Time) error { |
||||||
|
if err := t.SetReadDeadline(deadline); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return t.SetWriteDeadline(deadline) |
||||||
|
} |
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline.
|
||||||
|
// A zero value for t means Read will not time out.
|
||||||
|
// After the deadline, the error from Read will implement net.Error
|
||||||
|
// with Timeout() == true.
|
||||||
|
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { |
||||||
|
return errors.New("ssh: tcpChan: deadline not supported") |
||||||
|
} |
||||||
|
|
||||||
|
// SetWriteDeadline exists to satisfy the net.Conn interface
|
||||||
|
// but is not implemented by this type. It always returns an error.
|
||||||
|
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { |
||||||
|
return errors.New("ssh: tcpChan: deadline not supported") |
||||||
|
} |
@ -0,0 +1,20 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestAutoPortListenBroken(t *testing.T) { |
||||||
|
broken := "SSH-2.0-OpenSSH_5.9hh11" |
||||||
|
works := "SSH-2.0-OpenSSH_6.1" |
||||||
|
if !isBrokenOpenSSHVersion(broken) { |
||||||
|
t.Errorf("version %q not marked as broken", broken) |
||||||
|
} |
||||||
|
if isBrokenOpenSSHVersion(works) { |
||||||
|
t.Errorf("version %q marked as broken", works) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,892 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package terminal |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"io" |
||||||
|
"sync" |
||||||
|
"unicode/utf8" |
||||||
|
) |
||||||
|
|
||||||
|
// EscapeCodes contains escape sequences that can be written to the terminal in
|
||||||
|
// order to achieve different styles of text.
|
||||||
|
type EscapeCodes struct { |
||||||
|
// Foreground colors
|
||||||
|
Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte |
||||||
|
|
||||||
|
// Reset all attributes
|
||||||
|
Reset []byte |
||||||
|
} |
||||||
|
|
||||||
|
var vt100EscapeCodes = EscapeCodes{ |
||||||
|
Black: []byte{keyEscape, '[', '3', '0', 'm'}, |
||||||
|
Red: []byte{keyEscape, '[', '3', '1', 'm'}, |
||||||
|
Green: []byte{keyEscape, '[', '3', '2', 'm'}, |
||||||
|
Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, |
||||||
|
Blue: []byte{keyEscape, '[', '3', '4', 'm'}, |
||||||
|
Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, |
||||||
|
Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, |
||||||
|
White: []byte{keyEscape, '[', '3', '7', 'm'}, |
||||||
|
|
||||||
|
Reset: []byte{keyEscape, '[', '0', 'm'}, |
||||||
|
} |
||||||
|
|
||||||
|
// Terminal contains the state for running a VT100 terminal that is capable of
|
||||||
|
// reading lines of input.
|
||||||
|
type Terminal struct { |
||||||
|
// AutoCompleteCallback, if non-null, is called for each keypress with
|
||||||
|
// the full input line and the current position of the cursor (in
|
||||||
|
// bytes, as an index into |line|). If it returns ok=false, the key
|
||||||
|
// press is processed normally. Otherwise it returns a replacement line
|
||||||
|
// and the new cursor position.
|
||||||
|
AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) |
||||||
|
|
||||||
|
// Escape contains a pointer to the escape codes for this terminal.
|
||||||
|
// It's always a valid pointer, although the escape codes themselves
|
||||||
|
// may be empty if the terminal doesn't support them.
|
||||||
|
Escape *EscapeCodes |
||||||
|
|
||||||
|
// lock protects the terminal and the state in this object from
|
||||||
|
// concurrent processing of a key press and a Write() call.
|
||||||
|
lock sync.Mutex |
||||||
|
|
||||||
|
c io.ReadWriter |
||||||
|
prompt []rune |
||||||
|
|
||||||
|
// line is the current line being entered.
|
||||||
|
line []rune |
||||||
|
// pos is the logical position of the cursor in line
|
||||||
|
pos int |
||||||
|
// echo is true if local echo is enabled
|
||||||
|
echo bool |
||||||
|
// pasteActive is true iff there is a bracketed paste operation in
|
||||||
|
// progress.
|
||||||
|
pasteActive bool |
||||||
|
|
||||||
|
// cursorX contains the current X value of the cursor where the left
|
||||||
|
// edge is 0. cursorY contains the row number where the first row of
|
||||||
|
// the current line is 0.
|
||||||
|
cursorX, cursorY int |
||||||
|
// maxLine is the greatest value of cursorY so far.
|
||||||
|
maxLine int |
||||||
|
|
||||||
|
termWidth, termHeight int |
||||||
|
|
||||||
|
// outBuf contains the terminal data to be sent.
|
||||||
|
outBuf []byte |
||||||
|
// remainder contains the remainder of any partial key sequences after
|
||||||
|
// a read. It aliases into inBuf.
|
||||||
|
remainder []byte |
||||||
|
inBuf [256]byte |
||||||
|
|
||||||
|
// history contains previously entered commands so that they can be
|
||||||
|
// accessed with the up and down keys.
|
||||||
|
history stRingBuffer |
||||||
|
// historyIndex stores the currently accessed history entry, where zero
|
||||||
|
// means the immediately previous entry.
|
||||||
|
historyIndex int |
||||||
|
// When navigating up and down the history it's possible to return to
|
||||||
|
// the incomplete, initial line. That value is stored in
|
||||||
|
// historyPending.
|
||||||
|
historyPending string |
||||||
|
} |
||||||
|
|
||||||
|
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
||||||
|
// a local terminal, that terminal must first have been put into raw mode.
|
||||||
|
// prompt is a string that is written at the start of each input line (i.e.
|
||||||
|
// "> ").
|
||||||
|
func NewTerminal(c io.ReadWriter, prompt string) *Terminal { |
||||||
|
return &Terminal{ |
||||||
|
Escape: &vt100EscapeCodes, |
||||||
|
c: c, |
||||||
|
prompt: []rune(prompt), |
||||||
|
termWidth: 80, |
||||||
|
termHeight: 24, |
||||||
|
echo: true, |
||||||
|
historyIndex: -1, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
const ( |
||||||
|
keyCtrlD = 4 |
||||||
|
keyCtrlU = 21 |
||||||
|
keyEnter = '\r' |
||||||
|
keyEscape = 27 |
||||||
|
keyBackspace = 127 |
||||||
|
keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota |
||||||
|
keyUp |
||||||
|
keyDown |
||||||
|
keyLeft |
||||||
|
keyRight |
||||||
|
keyAltLeft |
||||||
|
keyAltRight |
||||||
|
keyHome |
||||||
|
keyEnd |
||||||
|
keyDeleteWord |
||||||
|
keyDeleteLine |
||||||
|
keyClearScreen |
||||||
|
keyPasteStart |
||||||
|
keyPasteEnd |
||||||
|
) |
||||||
|
|
||||||
|
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} |
||||||
|
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} |
||||||
|
|
||||||
|
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
||||||
|
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
|
||||||
|
func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { |
||||||
|
if len(b) == 0 { |
||||||
|
return utf8.RuneError, nil |
||||||
|
} |
||||||
|
|
||||||
|
if !pasteActive { |
||||||
|
switch b[0] { |
||||||
|
case 1: // ^A
|
||||||
|
return keyHome, b[1:] |
||||||
|
case 5: // ^E
|
||||||
|
return keyEnd, b[1:] |
||||||
|
case 8: // ^H
|
||||||
|
return keyBackspace, b[1:] |
||||||
|
case 11: // ^K
|
||||||
|
return keyDeleteLine, b[1:] |
||||||
|
case 12: // ^L
|
||||||
|
return keyClearScreen, b[1:] |
||||||
|
case 23: // ^W
|
||||||
|
return keyDeleteWord, b[1:] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if b[0] != keyEscape { |
||||||
|
if !utf8.FullRune(b) { |
||||||
|
return utf8.RuneError, b |
||||||
|
} |
||||||
|
r, l := utf8.DecodeRune(b) |
||||||
|
return r, b[l:] |
||||||
|
} |
||||||
|
|
||||||
|
if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { |
||||||
|
switch b[2] { |
||||||
|
case 'A': |
||||||
|
return keyUp, b[3:] |
||||||
|
case 'B': |
||||||
|
return keyDown, b[3:] |
||||||
|
case 'C': |
||||||
|
return keyRight, b[3:] |
||||||
|
case 'D': |
||||||
|
return keyLeft, b[3:] |
||||||
|
case 'H': |
||||||
|
return keyHome, b[3:] |
||||||
|
case 'F': |
||||||
|
return keyEnd, b[3:] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { |
||||||
|
switch b[5] { |
||||||
|
case 'C': |
||||||
|
return keyAltRight, b[6:] |
||||||
|
case 'D': |
||||||
|
return keyAltLeft, b[6:] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { |
||||||
|
return keyPasteStart, b[6:] |
||||||
|
} |
||||||
|
|
||||||
|
if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { |
||||||
|
return keyPasteEnd, b[6:] |
||||||
|
} |
||||||
|
|
||||||
|
// If we get here then we have a key that we don't recognise, or a
|
||||||
|
// partial sequence. It's not clear how one should find the end of a
|
||||||
|
// sequence without knowing them all, but it seems that [a-zA-Z~] only
|
||||||
|
// appears at the end of a sequence.
|
||||||
|
for i, c := range b[0:] { |
||||||
|
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { |
||||||
|
return keyUnknown, b[i+1:] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return utf8.RuneError, b |
||||||
|
} |
||||||
|
|
||||||
|
// queue appends data to the end of t.outBuf
|
||||||
|
func (t *Terminal) queue(data []rune) { |
||||||
|
t.outBuf = append(t.outBuf, []byte(string(data))...) |
||||||
|
} |
||||||
|
|
||||||
|
var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} |
||||||
|
var space = []rune{' '} |
||||||
|
|
||||||
|
func isPrintable(key rune) bool { |
||||||
|
isInSurrogateArea := key >= 0xd800 && key <= 0xdbff |
||||||
|
return key >= 32 && !isInSurrogateArea |
||||||
|
} |
||||||
|
|
||||||
|
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
|
||||||
|
// given, logical position in the text.
|
||||||
|
func (t *Terminal) moveCursorToPos(pos int) { |
||||||
|
if !t.echo { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
x := visualLength(t.prompt) + pos |
||||||
|
y := x / t.termWidth |
||||||
|
x = x % t.termWidth |
||||||
|
|
||||||
|
up := 0 |
||||||
|
if y < t.cursorY { |
||||||
|
up = t.cursorY - y |
||||||
|
} |
||||||
|
|
||||||
|
down := 0 |
||||||
|
if y > t.cursorY { |
||||||
|
down = y - t.cursorY |
||||||
|
} |
||||||
|
|
||||||
|
left := 0 |
||||||
|
if x < t.cursorX { |
||||||
|
left = t.cursorX - x |
||||||
|
} |
||||||
|
|
||||||
|
right := 0 |
||||||
|
if x > t.cursorX { |
||||||
|
right = x - t.cursorX |
||||||
|
} |
||||||
|
|
||||||
|
t.cursorX = x |
||||||
|
t.cursorY = y |
||||||
|
t.move(up, down, left, right) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) move(up, down, left, right int) { |
||||||
|
movement := make([]rune, 3*(up+down+left+right)) |
||||||
|
m := movement |
||||||
|
for i := 0; i < up; i++ { |
||||||
|
m[0] = keyEscape |
||||||
|
m[1] = '[' |
||||||
|
m[2] = 'A' |
||||||
|
m = m[3:] |
||||||
|
} |
||||||
|
for i := 0; i < down; i++ { |
||||||
|
m[0] = keyEscape |
||||||
|
m[1] = '[' |
||||||
|
m[2] = 'B' |
||||||
|
m = m[3:] |
||||||
|
} |
||||||
|
for i := 0; i < left; i++ { |
||||||
|
m[0] = keyEscape |
||||||
|
m[1] = '[' |
||||||
|
m[2] = 'D' |
||||||
|
m = m[3:] |
||||||
|
} |
||||||
|
for i := 0; i < right; i++ { |
||||||
|
m[0] = keyEscape |
||||||
|
m[1] = '[' |
||||||
|
m[2] = 'C' |
||||||
|
m = m[3:] |
||||||
|
} |
||||||
|
|
||||||
|
t.queue(movement) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) clearLineToRight() { |
||||||
|
op := []rune{keyEscape, '[', 'K'} |
||||||
|
t.queue(op) |
||||||
|
} |
||||||
|
|
||||||
|
const maxLineLength = 4096 |
||||||
|
|
||||||
|
func (t *Terminal) setLine(newLine []rune, newPos int) { |
||||||
|
if t.echo { |
||||||
|
t.moveCursorToPos(0) |
||||||
|
t.writeLine(newLine) |
||||||
|
for i := len(newLine); i < len(t.line); i++ { |
||||||
|
t.writeLine(space) |
||||||
|
} |
||||||
|
t.moveCursorToPos(newPos) |
||||||
|
} |
||||||
|
t.line = newLine |
||||||
|
t.pos = newPos |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) advanceCursor(places int) { |
||||||
|
t.cursorX += places |
||||||
|
t.cursorY += t.cursorX / t.termWidth |
||||||
|
if t.cursorY > t.maxLine { |
||||||
|
t.maxLine = t.cursorY |
||||||
|
} |
||||||
|
t.cursorX = t.cursorX % t.termWidth |
||||||
|
|
||||||
|
if places > 0 && t.cursorX == 0 { |
||||||
|
// Normally terminals will advance the current position
|
||||||
|
// when writing a character. But that doesn't happen
|
||||||
|
// for the last character in a line. However, when
|
||||||
|
// writing a character (except a new line) that causes
|
||||||
|
// a line wrap, the position will be advanced two
|
||||||
|
// places.
|
||||||
|
//
|
||||||
|
// So, if we are stopping at the end of a line, we
|
||||||
|
// need to write a newline so that our cursor can be
|
||||||
|
// advanced to the next line.
|
||||||
|
t.outBuf = append(t.outBuf, '\n') |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) eraseNPreviousChars(n int) { |
||||||
|
if n == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if t.pos < n { |
||||||
|
n = t.pos |
||||||
|
} |
||||||
|
t.pos -= n |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
|
||||||
|
copy(t.line[t.pos:], t.line[n+t.pos:]) |
||||||
|
t.line = t.line[:len(t.line)-n] |
||||||
|
if t.echo { |
||||||
|
t.writeLine(t.line[t.pos:]) |
||||||
|
for i := 0; i < n; i++ { |
||||||
|
t.queue(space) |
||||||
|
} |
||||||
|
t.advanceCursor(n) |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// countToLeftWord returns then number of characters from the cursor to the
|
||||||
|
// start of the previous word.
|
||||||
|
func (t *Terminal) countToLeftWord() int { |
||||||
|
if t.pos == 0 { |
||||||
|
return 0 |
||||||
|
} |
||||||
|
|
||||||
|
pos := t.pos - 1 |
||||||
|
for pos > 0 { |
||||||
|
if t.line[pos] != ' ' { |
||||||
|
break |
||||||
|
} |
||||||
|
pos-- |
||||||
|
} |
||||||
|
for pos > 0 { |
||||||
|
if t.line[pos] == ' ' { |
||||||
|
pos++ |
||||||
|
break |
||||||
|
} |
||||||
|
pos-- |
||||||
|
} |
||||||
|
|
||||||
|
return t.pos - pos |
||||||
|
} |
||||||
|
|
||||||
|
// countToRightWord returns then number of characters from the cursor to the
|
||||||
|
// start of the next word.
|
||||||
|
func (t *Terminal) countToRightWord() int { |
||||||
|
pos := t.pos |
||||||
|
for pos < len(t.line) { |
||||||
|
if t.line[pos] == ' ' { |
||||||
|
break |
||||||
|
} |
||||||
|
pos++ |
||||||
|
} |
||||||
|
for pos < len(t.line) { |
||||||
|
if t.line[pos] != ' ' { |
||||||
|
break |
||||||
|
} |
||||||
|
pos++ |
||||||
|
} |
||||||
|
return pos - t.pos |
||||||
|
} |
||||||
|
|
||||||
|
// visualLength returns the number of visible glyphs in s.
|
||||||
|
func visualLength(runes []rune) int { |
||||||
|
inEscapeSeq := false |
||||||
|
length := 0 |
||||||
|
|
||||||
|
for _, r := range runes { |
||||||
|
switch { |
||||||
|
case inEscapeSeq: |
||||||
|
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { |
||||||
|
inEscapeSeq = false |
||||||
|
} |
||||||
|
case r == '\x1b': |
||||||
|
inEscapeSeq = true |
||||||
|
default: |
||||||
|
length++ |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return length |
||||||
|
} |
||||||
|
|
||||||
|
// handleKey processes the given key and, optionally, returns a line of text
|
||||||
|
// that the user has entered.
|
||||||
|
func (t *Terminal) handleKey(key rune) (line string, ok bool) { |
||||||
|
if t.pasteActive && key != keyEnter { |
||||||
|
t.addKeyToLine(key) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
switch key { |
||||||
|
case keyBackspace: |
||||||
|
if t.pos == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
t.eraseNPreviousChars(1) |
||||||
|
case keyAltLeft: |
||||||
|
// move left by a word.
|
||||||
|
t.pos -= t.countToLeftWord() |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyAltRight: |
||||||
|
// move right by a word.
|
||||||
|
t.pos += t.countToRightWord() |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyLeft: |
||||||
|
if t.pos == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
t.pos-- |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyRight: |
||||||
|
if t.pos == len(t.line) { |
||||||
|
return |
||||||
|
} |
||||||
|
t.pos++ |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyHome: |
||||||
|
if t.pos == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
t.pos = 0 |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyEnd: |
||||||
|
if t.pos == len(t.line) { |
||||||
|
return |
||||||
|
} |
||||||
|
t.pos = len(t.line) |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyUp: |
||||||
|
entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) |
||||||
|
if !ok { |
||||||
|
return "", false |
||||||
|
} |
||||||
|
if t.historyIndex == -1 { |
||||||
|
t.historyPending = string(t.line) |
||||||
|
} |
||||||
|
t.historyIndex++ |
||||||
|
runes := []rune(entry) |
||||||
|
t.setLine(runes, len(runes)) |
||||||
|
case keyDown: |
||||||
|
switch t.historyIndex { |
||||||
|
case -1: |
||||||
|
return |
||||||
|
case 0: |
||||||
|
runes := []rune(t.historyPending) |
||||||
|
t.setLine(runes, len(runes)) |
||||||
|
t.historyIndex-- |
||||||
|
default: |
||||||
|
entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) |
||||||
|
if ok { |
||||||
|
t.historyIndex-- |
||||||
|
runes := []rune(entry) |
||||||
|
t.setLine(runes, len(runes)) |
||||||
|
} |
||||||
|
} |
||||||
|
case keyEnter: |
||||||
|
t.moveCursorToPos(len(t.line)) |
||||||
|
t.queue([]rune("\r\n")) |
||||||
|
line = string(t.line) |
||||||
|
ok = true |
||||||
|
t.line = t.line[:0] |
||||||
|
t.pos = 0 |
||||||
|
t.cursorX = 0 |
||||||
|
t.cursorY = 0 |
||||||
|
t.maxLine = 0 |
||||||
|
case keyDeleteWord: |
||||||
|
// Delete zero or more spaces and then one or more characters.
|
||||||
|
t.eraseNPreviousChars(t.countToLeftWord()) |
||||||
|
case keyDeleteLine: |
||||||
|
// Delete everything from the current cursor position to the
|
||||||
|
// end of line.
|
||||||
|
for i := t.pos; i < len(t.line); i++ { |
||||||
|
t.queue(space) |
||||||
|
t.advanceCursor(1) |
||||||
|
} |
||||||
|
t.line = t.line[:t.pos] |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
case keyCtrlD: |
||||||
|
// Erase the character under the current position.
|
||||||
|
// The EOF case when the line is empty is handled in
|
||||||
|
// readLine().
|
||||||
|
if t.pos < len(t.line) { |
||||||
|
t.pos++ |
||||||
|
t.eraseNPreviousChars(1) |
||||||
|
} |
||||||
|
case keyCtrlU: |
||||||
|
t.eraseNPreviousChars(t.pos) |
||||||
|
case keyClearScreen: |
||||||
|
// Erases the screen and moves the cursor to the home position.
|
||||||
|
t.queue([]rune("\x1b[2J\x1b[H")) |
||||||
|
t.queue(t.prompt) |
||||||
|
t.cursorX, t.cursorY = 0, 0 |
||||||
|
t.advanceCursor(visualLength(t.prompt)) |
||||||
|
t.setLine(t.line, t.pos) |
||||||
|
default: |
||||||
|
if t.AutoCompleteCallback != nil { |
||||||
|
prefix := string(t.line[:t.pos]) |
||||||
|
suffix := string(t.line[t.pos:]) |
||||||
|
|
||||||
|
t.lock.Unlock() |
||||||
|
newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) |
||||||
|
t.lock.Lock() |
||||||
|
|
||||||
|
if completeOk { |
||||||
|
t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
if !isPrintable(key) { |
||||||
|
return |
||||||
|
} |
||||||
|
if len(t.line) == maxLineLength { |
||||||
|
return |
||||||
|
} |
||||||
|
t.addKeyToLine(key) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// addKeyToLine inserts the given key at the current position in the current
|
||||||
|
// line.
|
||||||
|
func (t *Terminal) addKeyToLine(key rune) { |
||||||
|
if len(t.line) == cap(t.line) { |
||||||
|
newLine := make([]rune, len(t.line), 2*(1+len(t.line))) |
||||||
|
copy(newLine, t.line) |
||||||
|
t.line = newLine |
||||||
|
} |
||||||
|
t.line = t.line[:len(t.line)+1] |
||||||
|
copy(t.line[t.pos+1:], t.line[t.pos:]) |
||||||
|
t.line[t.pos] = key |
||||||
|
if t.echo { |
||||||
|
t.writeLine(t.line[t.pos:]) |
||||||
|
} |
||||||
|
t.pos++ |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) writeLine(line []rune) { |
||||||
|
for len(line) != 0 { |
||||||
|
remainingOnLine := t.termWidth - t.cursorX |
||||||
|
todo := len(line) |
||||||
|
if todo > remainingOnLine { |
||||||
|
todo = remainingOnLine |
||||||
|
} |
||||||
|
t.queue(line[:todo]) |
||||||
|
t.advanceCursor(visualLength(line[:todo])) |
||||||
|
line = line[todo:] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) Write(buf []byte) (n int, err error) { |
||||||
|
t.lock.Lock() |
||||||
|
defer t.lock.Unlock() |
||||||
|
|
||||||
|
if t.cursorX == 0 && t.cursorY == 0 { |
||||||
|
// This is the easy case: there's nothing on the screen that we
|
||||||
|
// have to move out of the way.
|
||||||
|
return t.c.Write(buf) |
||||||
|
} |
||||||
|
|
||||||
|
// We have a prompt and possibly user input on the screen. We
|
||||||
|
// have to clear it first.
|
||||||
|
t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) |
||||||
|
t.cursorX = 0 |
||||||
|
t.clearLineToRight() |
||||||
|
|
||||||
|
for t.cursorY > 0 { |
||||||
|
t.move(1 /* up */, 0, 0, 0) |
||||||
|
t.cursorY-- |
||||||
|
t.clearLineToRight() |
||||||
|
} |
||||||
|
|
||||||
|
if _, err = t.c.Write(t.outBuf); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
t.outBuf = t.outBuf[:0] |
||||||
|
|
||||||
|
if n, err = t.c.Write(buf); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
t.writeLine(t.prompt) |
||||||
|
if t.echo { |
||||||
|
t.writeLine(t.line) |
||||||
|
} |
||||||
|
|
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
|
||||||
|
if _, err = t.c.Write(t.outBuf); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
t.outBuf = t.outBuf[:0] |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// ReadPassword temporarily changes the prompt and reads a password, without
|
||||||
|
// echo, from the terminal.
|
||||||
|
func (t *Terminal) ReadPassword(prompt string) (line string, err error) { |
||||||
|
t.lock.Lock() |
||||||
|
defer t.lock.Unlock() |
||||||
|
|
||||||
|
oldPrompt := t.prompt |
||||||
|
t.prompt = []rune(prompt) |
||||||
|
t.echo = false |
||||||
|
|
||||||
|
line, err = t.readLine() |
||||||
|
|
||||||
|
t.prompt = oldPrompt |
||||||
|
t.echo = true |
||||||
|
|
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// ReadLine returns a line of input from the terminal.
|
||||||
|
func (t *Terminal) ReadLine() (line string, err error) { |
||||||
|
t.lock.Lock() |
||||||
|
defer t.lock.Unlock() |
||||||
|
|
||||||
|
return t.readLine() |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) readLine() (line string, err error) { |
||||||
|
// t.lock must be held at this point
|
||||||
|
|
||||||
|
if t.cursorX == 0 && t.cursorY == 0 { |
||||||
|
t.writeLine(t.prompt) |
||||||
|
t.c.Write(t.outBuf) |
||||||
|
t.outBuf = t.outBuf[:0] |
||||||
|
} |
||||||
|
|
||||||
|
lineIsPasted := t.pasteActive |
||||||
|
|
||||||
|
for { |
||||||
|
rest := t.remainder |
||||||
|
lineOk := false |
||||||
|
for !lineOk { |
||||||
|
var key rune |
||||||
|
key, rest = bytesToKey(rest, t.pasteActive) |
||||||
|
if key == utf8.RuneError { |
||||||
|
break |
||||||
|
} |
||||||
|
if !t.pasteActive { |
||||||
|
if key == keyCtrlD { |
||||||
|
if len(t.line) == 0 { |
||||||
|
return "", io.EOF |
||||||
|
} |
||||||
|
} |
||||||
|
if key == keyPasteStart { |
||||||
|
t.pasteActive = true |
||||||
|
if len(t.line) == 0 { |
||||||
|
lineIsPasted = true |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
} else if key == keyPasteEnd { |
||||||
|
t.pasteActive = false |
||||||
|
continue |
||||||
|
} |
||||||
|
if !t.pasteActive { |
||||||
|
lineIsPasted = false |
||||||
|
} |
||||||
|
line, lineOk = t.handleKey(key) |
||||||
|
} |
||||||
|
if len(rest) > 0 { |
||||||
|
n := copy(t.inBuf[:], rest) |
||||||
|
t.remainder = t.inBuf[:n] |
||||||
|
} else { |
||||||
|
t.remainder = nil |
||||||
|
} |
||||||
|
t.c.Write(t.outBuf) |
||||||
|
t.outBuf = t.outBuf[:0] |
||||||
|
if lineOk { |
||||||
|
if t.echo { |
||||||
|
t.historyIndex = -1 |
||||||
|
t.history.Add(line) |
||||||
|
} |
||||||
|
if lineIsPasted { |
||||||
|
err = ErrPasteIndicator |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// t.remainder is a slice at the beginning of t.inBuf
|
||||||
|
// containing a partial key sequence
|
||||||
|
readBuf := t.inBuf[len(t.remainder):] |
||||||
|
var n int |
||||||
|
|
||||||
|
t.lock.Unlock() |
||||||
|
n, err = t.c.Read(readBuf) |
||||||
|
t.lock.Lock() |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
t.remainder = t.inBuf[:n+len(t.remainder)] |
||||||
|
} |
||||||
|
|
||||||
|
panic("unreachable") // for Go 1.0.
|
||||||
|
} |
||||||
|
|
||||||
|
// SetPrompt sets the prompt to be used when reading subsequent lines.
|
||||||
|
func (t *Terminal) SetPrompt(prompt string) { |
||||||
|
t.lock.Lock() |
||||||
|
defer t.lock.Unlock() |
||||||
|
|
||||||
|
t.prompt = []rune(prompt) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { |
||||||
|
// Move cursor to column zero at the start of the line.
|
||||||
|
t.move(t.cursorY, 0, t.cursorX, 0) |
||||||
|
t.cursorX, t.cursorY = 0, 0 |
||||||
|
t.clearLineToRight() |
||||||
|
for t.cursorY < numPrevLines { |
||||||
|
// Move down a line
|
||||||
|
t.move(0, 1, 0, 0) |
||||||
|
t.cursorY++ |
||||||
|
t.clearLineToRight() |
||||||
|
} |
||||||
|
// Move back to beginning.
|
||||||
|
t.move(t.cursorY, 0, 0, 0) |
||||||
|
t.cursorX, t.cursorY = 0, 0 |
||||||
|
|
||||||
|
t.queue(t.prompt) |
||||||
|
t.advanceCursor(visualLength(t.prompt)) |
||||||
|
t.writeLine(t.line) |
||||||
|
t.moveCursorToPos(t.pos) |
||||||
|
} |
||||||
|
|
||||||
|
func (t *Terminal) SetSize(width, height int) error { |
||||||
|
t.lock.Lock() |
||||||
|
defer t.lock.Unlock() |
||||||
|
|
||||||
|
if width == 0 { |
||||||
|
width = 1 |
||||||
|
} |
||||||
|
|
||||||
|
oldWidth := t.termWidth |
||||||
|
t.termWidth, t.termHeight = width, height |
||||||
|
|
||||||
|
switch { |
||||||
|
case width == oldWidth: |
||||||
|
// If the width didn't change then nothing else needs to be
|
||||||
|
// done.
|
||||||
|
return nil |
||||||
|
case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: |
||||||
|
// If there is nothing on current line and no prompt printed,
|
||||||
|
// just do nothing
|
||||||
|
return nil |
||||||
|
case width < oldWidth: |
||||||
|
// Some terminals (e.g. xterm) will truncate lines that were
|
||||||
|
// too long when shinking. Others, (e.g. gnome-terminal) will
|
||||||
|
// attempt to wrap them. For the former, repainting t.maxLine
|
||||||
|
// works great, but that behaviour goes badly wrong in the case
|
||||||
|
// of the latter because they have doubled every full line.
|
||||||
|
|
||||||
|
// We assume that we are working on a terminal that wraps lines
|
||||||
|
// and adjust the cursor position based on every previous line
|
||||||
|
// wrapping and turning into two. This causes the prompt on
|
||||||
|
// xterms to move upwards, which isn't great, but it avoids a
|
||||||
|
// huge mess with gnome-terminal.
|
||||||
|
if t.cursorX >= t.termWidth { |
||||||
|
t.cursorX = t.termWidth - 1 |
||||||
|
} |
||||||
|
t.cursorY *= 2 |
||||||
|
t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) |
||||||
|
case width > oldWidth: |
||||||
|
// If the terminal expands then our position calculations will
|
||||||
|
// be wrong in the future because we think the cursor is
|
||||||
|
// |t.pos| chars into the string, but there will be a gap at
|
||||||
|
// the end of any wrapped line.
|
||||||
|
//
|
||||||
|
// But the position will actually be correct until we move, so
|
||||||
|
// we can move back to the beginning and repaint everything.
|
||||||
|
t.clearAndRepaintLinePlusNPrevious(t.maxLine) |
||||||
|
} |
||||||
|
|
||||||
|
_, err := t.c.Write(t.outBuf) |
||||||
|
t.outBuf = t.outBuf[:0] |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
type pasteIndicatorError struct{} |
||||||
|
|
||||||
|
func (pasteIndicatorError) Error() string { |
||||||
|
return "terminal: ErrPasteIndicator not correctly handled" |
||||||
|
} |
||||||
|
|
||||||
|
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
|
||||||
|
// to valid line data. It indicates that bracketed paste mode is enabled and
|
||||||
|
// that the returned line consists only of pasted data. Programs may wish to
|
||||||
|
// interpret pasted data more literally than typed data.
|
||||||
|
var ErrPasteIndicator = pasteIndicatorError{} |
||||||
|
|
||||||
|
// SetBracketedPasteMode requests that the terminal bracket paste operations
|
||||||
|
// with markers. Not all terminals support this but, if it is supported, then
|
||||||
|
// enabling this mode will stop any autocomplete callback from running due to
|
||||||
|
// pastes. Additionally, any lines that are completely pasted will be returned
|
||||||
|
// from ReadLine with the error set to ErrPasteIndicator.
|
||||||
|
func (t *Terminal) SetBracketedPasteMode(on bool) { |
||||||
|
if on { |
||||||
|
io.WriteString(t.c, "\x1b[?2004h") |
||||||
|
} else { |
||||||
|
io.WriteString(t.c, "\x1b[?2004l") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// stRingBuffer is a ring buffer of strings.
|
||||||
|
type stRingBuffer struct { |
||||||
|
// entries contains max elements.
|
||||||
|
entries []string |
||||||
|
max int |
||||||
|
// head contains the index of the element most recently added to the ring.
|
||||||
|
head int |
||||||
|
// size contains the number of elements in the ring.
|
||||||
|
size int |
||||||
|
} |
||||||
|
|
||||||
|
func (s *stRingBuffer) Add(a string) { |
||||||
|
if s.entries == nil { |
||||||
|
const defaultNumEntries = 100 |
||||||
|
s.entries = make([]string, defaultNumEntries) |
||||||
|
s.max = defaultNumEntries |
||||||
|
} |
||||||
|
|
||||||
|
s.head = (s.head + 1) % s.max |
||||||
|
s.entries[s.head] = a |
||||||
|
if s.size < s.max { |
||||||
|
s.size++ |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// NthPreviousEntry returns the value passed to the nth previous call to Add.
|
||||||
|
// If n is zero then the immediately prior value is returned, if one, then the
|
||||||
|
// next most recent, and so on. If such an element doesn't exist then ok is
|
||||||
|
// false.
|
||||||
|
func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { |
||||||
|
if n >= s.size { |
||||||
|
return "", false |
||||||
|
} |
||||||
|
index := s.head - n |
||||||
|
if index < 0 { |
||||||
|
index += s.max |
||||||
|
} |
||||||
|
return s.entries[index], true |
||||||
|
} |
@ -0,0 +1,269 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package terminal |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
type MockTerminal struct { |
||||||
|
toSend []byte |
||||||
|
bytesPerRead int |
||||||
|
received []byte |
||||||
|
} |
||||||
|
|
||||||
|
func (c *MockTerminal) Read(data []byte) (n int, err error) { |
||||||
|
n = len(data) |
||||||
|
if n == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
if n > len(c.toSend) { |
||||||
|
n = len(c.toSend) |
||||||
|
} |
||||||
|
if n == 0 { |
||||||
|
return 0, io.EOF |
||||||
|
} |
||||||
|
if c.bytesPerRead > 0 && n > c.bytesPerRead { |
||||||
|
n = c.bytesPerRead |
||||||
|
} |
||||||
|
copy(data, c.toSend[:n]) |
||||||
|
c.toSend = c.toSend[n:] |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *MockTerminal) Write(data []byte) (n int, err error) { |
||||||
|
c.received = append(c.received, data...) |
||||||
|
return len(data), nil |
||||||
|
} |
||||||
|
|
||||||
|
func TestClose(t *testing.T) { |
||||||
|
c := &MockTerminal{} |
||||||
|
ss := NewTerminal(c, "> ") |
||||||
|
line, err := ss.ReadLine() |
||||||
|
if line != "" { |
||||||
|
t.Errorf("Expected empty line but got: %s", line) |
||||||
|
} |
||||||
|
if err != io.EOF { |
||||||
|
t.Errorf("Error should have been EOF but got: %s", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
var keyPressTests = []struct { |
||||||
|
in string |
||||||
|
line string |
||||||
|
err error |
||||||
|
throwAwayLines int |
||||||
|
}{ |
||||||
|
{ |
||||||
|
err: io.EOF, |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "foo\r", |
||||||
|
line: "foo", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a\x1b[Cb\r", // right
|
||||||
|
line: "ab", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a\x1b[Db\r", // left
|
||||||
|
line: "ba", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a\177b\r", // backspace
|
||||||
|
line: "b", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "\x1b[A\r", // up
|
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "\x1b[B\r", // down
|
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "line\x1b[A\x1b[B\r", // up then down
|
||||||
|
line: "line", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "line1\rline2\x1b[A\r", // recall previous line.
|
||||||
|
line: "line1", |
||||||
|
throwAwayLines: 1, |
||||||
|
}, |
||||||
|
{ |
||||||
|
// recall two previous lines and append.
|
||||||
|
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r", |
||||||
|
line: "line1xxx", |
||||||
|
throwAwayLines: 2, |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Ctrl-A to move to beginning of line followed by ^K to kill
|
||||||
|
// line.
|
||||||
|
in: "a b \001\013\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
|
||||||
|
// finally ^K to kill nothing.
|
||||||
|
in: "a b \001\005\013\r", |
||||||
|
line: "a b ", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "\027\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a\027\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a \027\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a b\027\r", |
||||||
|
line: "a ", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a b \027\r", |
||||||
|
line: "a ", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "one two thr\x1b[D\027\r", |
||||||
|
line: "one two r", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "\013\r", |
||||||
|
line: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "a\013\r", |
||||||
|
line: "a", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "ab\x1b[D\013\r", |
||||||
|
line: "a", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "Ξεσκεπάζω\r", |
||||||
|
line: "Ξεσκεπάζω", |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
|
||||||
|
line: "", |
||||||
|
throwAwayLines: 1, |
||||||
|
}, |
||||||
|
{ |
||||||
|
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
|
||||||
|
line: "£", |
||||||
|
throwAwayLines: 1, |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Ctrl-D at the end of the line should be ignored.
|
||||||
|
in: "a\004\r", |
||||||
|
line: "a", |
||||||
|
}, |
||||||
|
{ |
||||||
|
// a, b, left, Ctrl-D should erase the b.
|
||||||
|
in: "ab\x1b[D\004\r", |
||||||
|
line: "a", |
||||||
|
}, |
||||||
|
{ |
||||||
|
// a, b, c, d, left, left, ^U should erase to the beginning of
|
||||||
|
// the line.
|
||||||
|
in: "abcd\x1b[D\x1b[D\025\r", |
||||||
|
line: "cd", |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Bracketed paste mode: control sequences should be returned
|
||||||
|
// verbatim in paste mode.
|
||||||
|
in: "abc\x1b[200~de\177f\x1b[201~\177\r", |
||||||
|
line: "abcde\177", |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Enter in bracketed paste mode should still work.
|
||||||
|
in: "abc\x1b[200~d\refg\x1b[201~h\r", |
||||||
|
line: "efgh", |
||||||
|
throwAwayLines: 1, |
||||||
|
}, |
||||||
|
{ |
||||||
|
// Lines consisting entirely of pasted data should be indicated as such.
|
||||||
|
in: "\x1b[200~a\r", |
||||||
|
line: "a", |
||||||
|
err: ErrPasteIndicator, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
func TestKeyPresses(t *testing.T) { |
||||||
|
for i, test := range keyPressTests { |
||||||
|
for j := 1; j < len(test.in); j++ { |
||||||
|
c := &MockTerminal{ |
||||||
|
toSend: []byte(test.in), |
||||||
|
bytesPerRead: j, |
||||||
|
} |
||||||
|
ss := NewTerminal(c, "> ") |
||||||
|
for k := 0; k < test.throwAwayLines; k++ { |
||||||
|
_, err := ss.ReadLine() |
||||||
|
if err != nil { |
||||||
|
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err) |
||||||
|
} |
||||||
|
} |
||||||
|
line, err := ss.ReadLine() |
||||||
|
if line != test.line { |
||||||
|
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) |
||||||
|
break |
||||||
|
} |
||||||
|
if err != test.err { |
||||||
|
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestPasswordNotSaved(t *testing.T) { |
||||||
|
c := &MockTerminal{ |
||||||
|
toSend: []byte("password\r\x1b[A\r"), |
||||||
|
bytesPerRead: 1, |
||||||
|
} |
||||||
|
ss := NewTerminal(c, "> ") |
||||||
|
pw, _ := ss.ReadPassword("> ") |
||||||
|
if pw != "password" { |
||||||
|
t.Fatalf("failed to read password, got %s", pw) |
||||||
|
} |
||||||
|
line, _ := ss.ReadLine() |
||||||
|
if len(line) > 0 { |
||||||
|
t.Fatalf("password was saved in history") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
var setSizeTests = []struct { |
||||||
|
width, height int |
||||||
|
}{ |
||||||
|
{40, 13}, |
||||||
|
{80, 24}, |
||||||
|
{132, 43}, |
||||||
|
} |
||||||
|
|
||||||
|
func TestTerminalSetSize(t *testing.T) { |
||||||
|
for _, setSize := range setSizeTests { |
||||||
|
c := &MockTerminal{ |
||||||
|
toSend: []byte("password\r\x1b[A\r"), |
||||||
|
bytesPerRead: 1, |
||||||
|
} |
||||||
|
ss := NewTerminal(c, "> ") |
||||||
|
ss.SetSize(setSize.width, setSize.height) |
||||||
|
pw, _ := ss.ReadPassword("Password: ") |
||||||
|
if pw != "password" { |
||||||
|
t.Fatalf("failed to read password, got %s", pw) |
||||||
|
} |
||||||
|
if string(c.received) != "Password: \r\n" { |
||||||
|
t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,128 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd
|
||||||
|
|
||||||
|
// Package terminal provides support functions for dealing with terminals, as
|
||||||
|
// commonly found on UNIX systems.
|
||||||
|
//
|
||||||
|
// Putting a terminal into raw mode is the most common requirement:
|
||||||
|
//
|
||||||
|
// oldState, err := terminal.MakeRaw(0)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
// defer terminal.Restore(0, oldState)
|
||||||
|
package terminal // import "golang.org/x/crypto/ssh/terminal"
|
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"syscall" |
||||||
|
"unsafe" |
||||||
|
) |
||||||
|
|
||||||
|
// State contains the state of a terminal.
|
||||||
|
type State struct { |
||||||
|
termios syscall.Termios |
||||||
|
} |
||||||
|
|
||||||
|
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||||
|
func IsTerminal(fd int) bool { |
||||||
|
var termios syscall.Termios |
||||||
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) |
||||||
|
return err == 0 |
||||||
|
} |
||||||
|
|
||||||
|
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||||
|
// mode and returns the previous state of the terminal so that it can be
|
||||||
|
// restored.
|
||||||
|
func MakeRaw(fd int) (*State, error) { |
||||||
|
var oldState State |
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
newState := oldState.termios |
||||||
|
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF |
||||||
|
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG |
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &oldState, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetState returns the current state of a terminal which may be useful to
|
||||||
|
// restore the terminal after a signal.
|
||||||
|
func GetState(fd int) (*State, error) { |
||||||
|
var oldState State |
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &oldState, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Restore restores the terminal connected to the given file descriptor to a
|
||||||
|
// previous state.
|
||||||
|
func Restore(fd int, state *State) error { |
||||||
|
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// GetSize returns the dimensions of the given terminal.
|
||||||
|
func GetSize(fd int) (width, height int, err error) { |
||||||
|
var dimensions [4]uint16 |
||||||
|
|
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { |
||||||
|
return -1, -1, err |
||||||
|
} |
||||||
|
return int(dimensions[1]), int(dimensions[0]), nil |
||||||
|
} |
||||||
|
|
||||||
|
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||||
|
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||||
|
// returned does not include the \n.
|
||||||
|
func ReadPassword(fd int) ([]byte, error) { |
||||||
|
var oldState syscall.Termios |
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
newState := oldState |
||||||
|
newState.Lflag &^= syscall.ECHO |
||||||
|
newState.Lflag |= syscall.ICANON | syscall.ISIG |
||||||
|
newState.Iflag |= syscall.ICRNL |
||||||
|
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
defer func() { |
||||||
|
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) |
||||||
|
}() |
||||||
|
|
||||||
|
var buf [16]byte |
||||||
|
var ret []byte |
||||||
|
for { |
||||||
|
n, err := syscall.Read(fd, buf[:]) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if n == 0 { |
||||||
|
if len(ret) == 0 { |
||||||
|
return nil, io.EOF |
||||||
|
} |
||||||
|
break |
||||||
|
} |
||||||
|
if buf[n-1] == '\n' { |
||||||
|
n-- |
||||||
|
} |
||||||
|
ret = append(ret, buf[:n]...) |
||||||
|
if n < len(buf) { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return ret, nil |
||||||
|
} |
@ -0,0 +1,12 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd netbsd openbsd
|
||||||
|
|
||||||
|
package terminal |
||||||
|
|
||||||
|
import "syscall" |
||||||
|
|
||||||
|
const ioctlReadTermios = syscall.TIOCGETA |
||||||
|
const ioctlWriteTermios = syscall.TIOCSETA |
@ -0,0 +1,11 @@ |
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package terminal |
||||||
|
|
||||||
|
// These constants are declared here, rather than importing
|
||||||
|
// them from the syscall package as some syscall packages, even
|
||||||
|
// on linux, for example gccgo, do not declare them.
|
||||||
|
const ioctlReadTermios = 0x5401 // syscall.TCGETS
|
||||||
|
const ioctlWriteTermios = 0x5402 // syscall.TCSETS
|
@ -0,0 +1,174 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
// Package terminal provides support functions for dealing with terminals, as
|
||||||
|
// commonly found on UNIX systems.
|
||||||
|
//
|
||||||
|
// Putting a terminal into raw mode is the most common requirement:
|
||||||
|
//
|
||||||
|
// oldState, err := terminal.MakeRaw(0)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
// defer terminal.Restore(0, oldState)
|
||||||
|
package terminal |
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"syscall" |
||||||
|
"unsafe" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
enableLineInput = 2 |
||||||
|
enableEchoInput = 4 |
||||||
|
enableProcessedInput = 1 |
||||||
|
enableWindowInput = 8 |
||||||
|
enableMouseInput = 16 |
||||||
|
enableInsertMode = 32 |
||||||
|
enableQuickEditMode = 64 |
||||||
|
enableExtendedFlags = 128 |
||||||
|
enableAutoPosition = 256 |
||||||
|
enableProcessedOutput = 1 |
||||||
|
enableWrapAtEolOutput = 2 |
||||||
|
) |
||||||
|
|
||||||
|
var kernel32 = syscall.NewLazyDLL("kernel32.dll") |
||||||
|
|
||||||
|
var ( |
||||||
|
procGetConsoleMode = kernel32.NewProc("GetConsoleMode") |
||||||
|
procSetConsoleMode = kernel32.NewProc("SetConsoleMode") |
||||||
|
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") |
||||||
|
) |
||||||
|
|
||||||
|
type ( |
||||||
|
short int16 |
||||||
|
word uint16 |
||||||
|
|
||||||
|
coord struct { |
||||||
|
x short |
||||||
|
y short |
||||||
|
} |
||||||
|
smallRect struct { |
||||||
|
left short |
||||||
|
top short |
||||||
|
right short |
||||||
|
bottom short |
||||||
|
} |
||||||
|
consoleScreenBufferInfo struct { |
||||||
|
size coord |
||||||
|
cursorPosition coord |
||||||
|
attributes word |
||||||
|
window smallRect |
||||||
|
maximumWindowSize coord |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
type State struct { |
||||||
|
mode uint32 |
||||||
|
} |
||||||
|
|
||||||
|
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||||
|
func IsTerminal(fd int) bool { |
||||||
|
var st uint32 |
||||||
|
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||||
|
return r != 0 && e == 0 |
||||||
|
} |
||||||
|
|
||||||
|
// MakeRaw put the terminal connected to the given file descriptor into raw
|
||||||
|
// mode and returns the previous state of the terminal so that it can be
|
||||||
|
// restored.
|
||||||
|
func MakeRaw(fd int) (*State, error) { |
||||||
|
var st uint32 |
||||||
|
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||||
|
if e != 0 { |
||||||
|
return nil, error(e) |
||||||
|
} |
||||||
|
st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) |
||||||
|
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) |
||||||
|
if e != 0 { |
||||||
|
return nil, error(e) |
||||||
|
} |
||||||
|
return &State{st}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetState returns the current state of a terminal which may be useful to
|
||||||
|
// restore the terminal after a signal.
|
||||||
|
func GetState(fd int) (*State, error) { |
||||||
|
var st uint32 |
||||||
|
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||||
|
if e != 0 { |
||||||
|
return nil, error(e) |
||||||
|
} |
||||||
|
return &State{st}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Restore restores the terminal connected to the given file descriptor to a
|
||||||
|
// previous state.
|
||||||
|
func Restore(fd int, state *State) error { |
||||||
|
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// GetSize returns the dimensions of the given terminal.
|
||||||
|
func GetSize(fd int) (width, height int, err error) { |
||||||
|
var info consoleScreenBufferInfo |
||||||
|
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) |
||||||
|
if e != 0 { |
||||||
|
return 0, 0, error(e) |
||||||
|
} |
||||||
|
return int(info.size.x), int(info.size.y), nil |
||||||
|
} |
||||||
|
|
||||||
|
// ReadPassword reads a line of input from a terminal without local echo. This
|
||||||
|
// is commonly used for inputting passwords and other sensitive data. The slice
|
||||||
|
// returned does not include the \n.
|
||||||
|
func ReadPassword(fd int) ([]byte, error) { |
||||||
|
var st uint32 |
||||||
|
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) |
||||||
|
if e != 0 { |
||||||
|
return nil, error(e) |
||||||
|
} |
||||||
|
old := st |
||||||
|
|
||||||
|
st &^= (enableEchoInput) |
||||||
|
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) |
||||||
|
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) |
||||||
|
if e != 0 { |
||||||
|
return nil, error(e) |
||||||
|
} |
||||||
|
|
||||||
|
defer func() { |
||||||
|
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) |
||||||
|
}() |
||||||
|
|
||||||
|
var buf [16]byte |
||||||
|
var ret []byte |
||||||
|
for { |
||||||
|
n, err := syscall.Read(syscall.Handle(fd), buf[:]) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if n == 0 { |
||||||
|
if len(ret) == 0 { |
||||||
|
return nil, io.EOF |
||||||
|
} |
||||||
|
break |
||||||
|
} |
||||||
|
if buf[n-1] == '\n' { |
||||||
|
n-- |
||||||
|
} |
||||||
|
if n > 0 && buf[n-1] == '\r' { |
||||||
|
n-- |
||||||
|
} |
||||||
|
ret = append(ret, buf[:n]...) |
||||||
|
if n < len(buf) { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return ret, nil |
||||||
|
} |
@ -0,0 +1,59 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd linux netbsd openbsd
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh" |
||||||
|
"golang.org/x/crypto/ssh/agent" |
||||||
|
) |
||||||
|
|
||||||
|
func TestAgentForward(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
keyring := agent.NewKeyring() |
||||||
|
if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil { |
||||||
|
t.Fatalf("Error adding key: %s", err) |
||||||
|
} |
||||||
|
if err := keyring.Add(agent.AddedKey{ |
||||||
|
PrivateKey: testPrivateKeys["dsa"], |
||||||
|
ConfirmBeforeUse: true, |
||||||
|
LifetimeSecs: 3600, |
||||||
|
}); err != nil { |
||||||
|
t.Fatalf("Error adding key with constraints: %s", err) |
||||||
|
} |
||||||
|
pub := testPublicKeys["dsa"] |
||||||
|
|
||||||
|
sess, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewSession: %v", err) |
||||||
|
} |
||||||
|
if err := agent.RequestAgentForwarding(sess); err != nil { |
||||||
|
t.Fatalf("RequestAgentForwarding: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if err := agent.ForwardToAgent(conn, keyring); err != nil { |
||||||
|
t.Fatalf("SetupForwardKeyring: %v", err) |
||||||
|
} |
||||||
|
out, err := sess.CombinedOutput("ssh-add -L") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("running ssh-add: %v, out %s", err, out) |
||||||
|
} |
||||||
|
key, _, _, _, err := ssh.ParseAuthorizedKey(out) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) |
||||||
|
} |
||||||
|
|
||||||
|
if !bytes.Equal(key.Marshal(), pub.Marshal()) { |
||||||
|
t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,47 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd linux netbsd openbsd
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
func TestCertLogin(t *testing.T) { |
||||||
|
s := newServer(t) |
||||||
|
defer s.Shutdown() |
||||||
|
|
||||||
|
// Use a key different from the default.
|
||||||
|
clientKey := testSigners["dsa"] |
||||||
|
caAuthKey := testSigners["ecdsa"] |
||||||
|
cert := &ssh.Certificate{ |
||||||
|
Key: clientKey.PublicKey(), |
||||||
|
ValidPrincipals: []string{username()}, |
||||||
|
CertType: ssh.UserCert, |
||||||
|
ValidBefore: ssh.CertTimeInfinity, |
||||||
|
} |
||||||
|
if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { |
||||||
|
t.Fatalf("SetSignature: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
certSigner, err := ssh.NewCertSigner(cert, clientKey) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("NewCertSigner: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
conf := &ssh.ClientConfig{ |
||||||
|
User: username(), |
||||||
|
} |
||||||
|
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) |
||||||
|
client, err := s.TryDial(conf) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("TryDial: %v", err) |
||||||
|
} |
||||||
|
client.Close() |
||||||
|
} |
@ -0,0 +1,7 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// This package contains integration tests for the
|
||||||
|
// golang.org/x/crypto/ssh package.
|
||||||
|
package test // import "golang.org/x/crypto/ssh/test"
|
@ -0,0 +1,160 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd linux netbsd openbsd
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"io" |
||||||
|
"io/ioutil" |
||||||
|
"math/rand" |
||||||
|
"net" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
func TestPortForward(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
sshListener, err := conn.Listen("tcp", "localhost:0") |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
sshConn, err := sshListener.Accept() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("listen.Accept failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
_, err = io.Copy(sshConn, sshConn) |
||||||
|
if err != nil && err != io.EOF { |
||||||
|
t.Fatalf("ssh client copy: %v", err) |
||||||
|
} |
||||||
|
sshConn.Close() |
||||||
|
}() |
||||||
|
|
||||||
|
forwardedAddr := sshListener.Addr().String() |
||||||
|
tcpConn, err := net.Dial("tcp", forwardedAddr) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("TCP dial failed: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
readChan := make(chan []byte) |
||||||
|
go func() { |
||||||
|
data, _ := ioutil.ReadAll(tcpConn) |
||||||
|
readChan <- data |
||||||
|
}() |
||||||
|
|
||||||
|
// Invent some data.
|
||||||
|
data := make([]byte, 100*1000) |
||||||
|
for i := range data { |
||||||
|
data[i] = byte(i % 255) |
||||||
|
} |
||||||
|
|
||||||
|
var sent []byte |
||||||
|
for len(sent) < 1000*1000 { |
||||||
|
// Send random sized chunks
|
||||||
|
m := rand.Intn(len(data)) |
||||||
|
n, err := tcpConn.Write(data[:m]) |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
sent = append(sent, data[:n]...) |
||||||
|
} |
||||||
|
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { |
||||||
|
t.Errorf("tcpConn.CloseWrite: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
read := <-readChan |
||||||
|
|
||||||
|
if len(sent) != len(read) { |
||||||
|
t.Fatalf("got %d bytes, want %d", len(read), len(sent)) |
||||||
|
} |
||||||
|
if bytes.Compare(sent, read) != 0 { |
||||||
|
t.Fatalf("read back data does not match") |
||||||
|
} |
||||||
|
|
||||||
|
if err := sshListener.Close(); err != nil { |
||||||
|
t.Fatalf("sshListener.Close: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Check that the forward disappeared.
|
||||||
|
tcpConn, err = net.Dial("tcp", forwardedAddr) |
||||||
|
if err == nil { |
||||||
|
tcpConn.Close() |
||||||
|
t.Errorf("still listening to %s after closing", forwardedAddr) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAcceptClose(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
|
||||||
|
sshListener, err := conn.Listen("tcp", "localhost:0") |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
quit := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
c, err := sshListener.Accept() |
||||||
|
if err != nil { |
||||||
|
quit <- err |
||||||
|
break |
||||||
|
} |
||||||
|
c.Close() |
||||||
|
} |
||||||
|
}() |
||||||
|
sshListener.Close() |
||||||
|
|
||||||
|
select { |
||||||
|
case <-time.After(1 * time.Second): |
||||||
|
t.Errorf("timeout: listener did not close.") |
||||||
|
case err := <-quit: |
||||||
|
t.Logf("quit as expected (error %v)", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Check that listeners exit if the underlying client transport dies.
|
||||||
|
func TestPortForwardConnectionClose(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
|
||||||
|
sshListener, err := conn.Listen("tcp", "localhost:0") |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
quit := make(chan error, 1) |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
c, err := sshListener.Accept() |
||||||
|
if err != nil { |
||||||
|
quit <- err |
||||||
|
break |
||||||
|
} |
||||||
|
c.Close() |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
// It would be even nicer if we closed the server side, but it
|
||||||
|
// is more involved as the fd for that side is dup()ed.
|
||||||
|
server.clientConn.Close() |
||||||
|
|
||||||
|
select { |
||||||
|
case <-time.After(1 * time.Second): |
||||||
|
t.Errorf("timeout: listener did not close.") |
||||||
|
case err := <-quit: |
||||||
|
t.Logf("quit as expected (error %v)", err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,340 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
// Session functional tests.
|
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh" |
||||||
|
) |
||||||
|
|
||||||
|
func TestRunCommandSuccess(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
err = session.Run("true") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestHostKeyCheck(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
|
||||||
|
conf := clientConfig() |
||||||
|
hostDB := hostKeyDB() |
||||||
|
conf.HostKeyCallback = hostDB.Check |
||||||
|
|
||||||
|
// change the keys.
|
||||||
|
hostDB.keys[ssh.KeyAlgoRSA][25]++ |
||||||
|
hostDB.keys[ssh.KeyAlgoDSA][25]++ |
||||||
|
hostDB.keys[ssh.KeyAlgoECDSA256][25]++ |
||||||
|
|
||||||
|
conn, err := server.TryDial(conf) |
||||||
|
if err == nil { |
||||||
|
conn.Close() |
||||||
|
t.Fatalf("dial should have failed.") |
||||||
|
} else if !strings.Contains(err.Error(), "host key mismatch") { |
||||||
|
t.Fatalf("'host key mismatch' not found in %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRunCommandStdin(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
r, w := io.Pipe() |
||||||
|
defer r.Close() |
||||||
|
defer w.Close() |
||||||
|
session.Stdin = r |
||||||
|
|
||||||
|
err = session.Run("true") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRunCommandStdinError(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
r, w := io.Pipe() |
||||||
|
defer r.Close() |
||||||
|
session.Stdin = r |
||||||
|
pipeErr := errors.New("closing write end of pipe") |
||||||
|
w.CloseWithError(pipeErr) |
||||||
|
|
||||||
|
err = session.Run("true") |
||||||
|
if err != pipeErr { |
||||||
|
t.Fatalf("expected %v, found %v", pipeErr, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRunCommandFailed(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
err = session.Run(`bash -c "kill -9 $$"`) |
||||||
|
if err == nil { |
||||||
|
t.Fatalf("session succeeded: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRunCommandWeClosed(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
err = session.Shell() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("shell failed: %v", err) |
||||||
|
} |
||||||
|
err = session.Close() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("shell failed: %v", err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestFuncLargeRead(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to create new session: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
stdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to acquire stdout pipe: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
err = session.Start("dd if=/dev/urandom bs=2048 count=1024") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to execute remote command: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
buf := new(bytes.Buffer) |
||||||
|
n, err := io.Copy(buf, stdout) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("error reading from remote stdout: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
if n != 2048*1024 { |
||||||
|
t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestKeyChange(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conf := clientConfig() |
||||||
|
hostDB := hostKeyDB() |
||||||
|
conf.HostKeyCallback = hostDB.Check |
||||||
|
conf.RekeyThreshold = 1024 |
||||||
|
conn := server.Dial(conf) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
for i := 0; i < 4; i++ { |
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to create new session: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
stdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to acquire stdout pipe: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
err = session.Start("dd if=/dev/urandom bs=1024 count=1") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to execute remote command: %s", err) |
||||||
|
} |
||||||
|
buf := new(bytes.Buffer) |
||||||
|
n, err := io.Copy(buf, stdout) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("error reading from remote stdout: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
want := int64(1024) |
||||||
|
if n != want { |
||||||
|
t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if changes := hostDB.checkCount; changes < 4 { |
||||||
|
t.Errorf("got %d key changes, want 4", changes) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestInvalidTerminalMode(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil { |
||||||
|
t.Fatalf("req-pty failed: successful request with invalid mode") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidTerminalMode(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conn := server.Dial(clientConfig()) |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
session, err := conn.NewSession() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %v", err) |
||||||
|
} |
||||||
|
defer session.Close() |
||||||
|
|
||||||
|
stdout, err := session.StdoutPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to acquire stdout pipe: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
stdin, err := session.StdinPipe() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("unable to acquire stdin pipe: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
tm := ssh.TerminalModes{ssh.ECHO: 0} |
||||||
|
if err = session.RequestPty("xterm", 80, 40, tm); err != nil { |
||||||
|
t.Fatalf("req-pty failed: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
err = session.Shell() |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("session failed: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
stdin.Write([]byte("stty -a && exit\n")) |
||||||
|
|
||||||
|
var buf bytes.Buffer |
||||||
|
if _, err := io.Copy(&buf, stdout); err != nil { |
||||||
|
t.Fatalf("reading failed: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") { |
||||||
|
t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCiphers(t *testing.T) { |
||||||
|
var config ssh.Config |
||||||
|
config.SetDefaults() |
||||||
|
cipherOrder := config.Ciphers |
||||||
|
// This cipher will not be tested when commented out in cipher.go it will
|
||||||
|
// fallback to the next available as per line 292.
|
||||||
|
cipherOrder = append(cipherOrder, "aes128-cbc") |
||||||
|
|
||||||
|
for _, ciph := range cipherOrder { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conf := clientConfig() |
||||||
|
conf.Ciphers = []string{ciph} |
||||||
|
// Don't fail if sshd doesnt have the cipher.
|
||||||
|
conf.Ciphers = append(conf.Ciphers, cipherOrder...) |
||||||
|
conn, err := server.TryDial(conf) |
||||||
|
if err == nil { |
||||||
|
conn.Close() |
||||||
|
} else { |
||||||
|
t.Fatalf("failed for cipher %q", ciph) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMACs(t *testing.T) { |
||||||
|
var config ssh.Config |
||||||
|
config.SetDefaults() |
||||||
|
macOrder := config.MACs |
||||||
|
|
||||||
|
for _, mac := range macOrder { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conf := clientConfig() |
||||||
|
conf.MACs = []string{mac} |
||||||
|
// Don't fail if sshd doesnt have the MAC.
|
||||||
|
conf.MACs = append(conf.MACs, macOrder...) |
||||||
|
if conn, err := server.TryDial(conf); err == nil { |
||||||
|
conn.Close() |
||||||
|
} else { |
||||||
|
t.Fatalf("failed for MAC %q", mac) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestKeyExchanges(t *testing.T) { |
||||||
|
var config ssh.Config |
||||||
|
config.SetDefaults() |
||||||
|
kexOrder := config.KeyExchanges |
||||||
|
for _, kex := range kexOrder { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
conf := clientConfig() |
||||||
|
// Don't fail if sshd doesnt have the kex.
|
||||||
|
conf.KeyExchanges = append([]string{kex}, kexOrder...) |
||||||
|
conn, err := server.TryDial(conf) |
||||||
|
if err == nil { |
||||||
|
conn.Close() |
||||||
|
} else { |
||||||
|
t.Errorf("failed for kex %q", kex) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,46 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
// direct-tcpip functional tests
|
||||||
|
|
||||||
|
import ( |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestDial(t *testing.T) { |
||||||
|
server := newServer(t) |
||||||
|
defer server.Shutdown() |
||||||
|
sshConn := server.Dial(clientConfig()) |
||||||
|
defer sshConn.Close() |
||||||
|
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0") |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Listen: %v", err) |
||||||
|
} |
||||||
|
defer l.Close() |
||||||
|
|
||||||
|
go func() { |
||||||
|
for { |
||||||
|
c, err := l.Accept() |
||||||
|
if err != nil { |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
io.WriteString(c, c.RemoteAddr().String()) |
||||||
|
c.Close() |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
conn, err := sshConn.Dial("tcp", l.Addr().String()) |
||||||
|
if err != nil { |
||||||
|
t.Fatalf("Dial: %v", err) |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
} |
@ -0,0 +1,261 @@ |
|||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build darwin dragonfly freebsd linux netbsd openbsd plan9
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
// functional test harness for unix.
|
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"fmt" |
||||||
|
"io/ioutil" |
||||||
|
"log" |
||||||
|
"net" |
||||||
|
"os" |
||||||
|
"os/exec" |
||||||
|
"os/user" |
||||||
|
"path/filepath" |
||||||
|
"testing" |
||||||
|
"text/template" |
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh" |
||||||
|
"golang.org/x/crypto/ssh/testdata" |
||||||
|
) |
||||||
|
|
||||||
|
const sshd_config = ` |
||||||
|
Protocol 2 |
||||||
|
HostKey {{.Dir}}/id_rsa |
||||||
|
HostKey {{.Dir}}/id_dsa |
||||||
|
HostKey {{.Dir}}/id_ecdsa |
||||||
|
Pidfile {{.Dir}}/sshd.pid |
||||||
|
#UsePrivilegeSeparation no |
||||||
|
KeyRegenerationInterval 3600 |
||||||
|
ServerKeyBits 768 |
||||||
|
SyslogFacility AUTH |
||||||
|
LogLevel DEBUG2 |
||||||
|
LoginGraceTime 120 |
||||||
|
PermitRootLogin no |
||||||
|
StrictModes no |
||||||
|
RSAAuthentication yes |
||||||
|
PubkeyAuthentication yes |
||||||
|
AuthorizedKeysFile {{.Dir}}/id_user.pub |
||||||
|
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub |
||||||
|
IgnoreRhosts yes |
||||||
|
RhostsRSAAuthentication no |
||||||
|
HostbasedAuthentication no |
||||||
|
` |
||||||
|
|
||||||
|
var configTmpl = template.Must(template.New("").Parse(sshd_config)) |
||||||
|
|
||||||
|
type server struct { |
||||||
|
t *testing.T |
||||||
|
cleanup func() // executed during Shutdown
|
||||||
|
configfile string |
||||||
|
cmd *exec.Cmd |
||||||
|
output bytes.Buffer // holds stderr from sshd process
|
||||||
|
|
||||||
|
// Client half of the network connection.
|
||||||
|
clientConn net.Conn |
||||||
|
} |
||||||
|
|
||||||
|
func username() string { |
||||||
|
var username string |
||||||
|
if user, err := user.Current(); err == nil { |
||||||
|
username = user.Username |
||||||
|
} else { |
||||||
|
// user.Current() currently requires cgo. If an error is
|
||||||
|
// returned attempt to get the username from the environment.
|
||||||
|
log.Printf("user.Current: %v; falling back on $USER", err) |
||||||
|
username = os.Getenv("USER") |
||||||
|
} |
||||||
|
if username == "" { |
||||||
|
panic("Unable to get username") |
||||||
|
} |
||||||
|
return username |
||||||
|
} |
||||||
|
|
||||||
|
type storedHostKey struct { |
||||||
|
// keys map from an algorithm string to binary key data.
|
||||||
|
keys map[string][]byte |
||||||
|
|
||||||
|
// checkCount counts the Check calls. Used for testing
|
||||||
|
// rekeying.
|
||||||
|
checkCount int |
||||||
|
} |
||||||
|
|
||||||
|
func (k *storedHostKey) Add(key ssh.PublicKey) { |
||||||
|
if k.keys == nil { |
||||||
|
k.keys = map[string][]byte{} |
||||||
|
} |
||||||
|
k.keys[key.Type()] = key.Marshal() |
||||||
|
} |
||||||
|
|
||||||
|
func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { |
||||||
|
k.checkCount++ |
||||||
|
algo := key.Type() |
||||||
|
|
||||||
|
if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { |
||||||
|
return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func hostKeyDB() *storedHostKey { |
||||||
|
keyChecker := &storedHostKey{} |
||||||
|
keyChecker.Add(testPublicKeys["ecdsa"]) |
||||||
|
keyChecker.Add(testPublicKeys["rsa"]) |
||||||
|
keyChecker.Add(testPublicKeys["dsa"]) |
||||||
|
return keyChecker |
||||||
|
} |
||||||
|
|
||||||
|
func clientConfig() *ssh.ClientConfig { |
||||||
|
config := &ssh.ClientConfig{ |
||||||
|
User: username(), |
||||||
|
Auth: []ssh.AuthMethod{ |
||||||
|
ssh.PublicKeys(testSigners["user"]), |
||||||
|
}, |
||||||
|
HostKeyCallback: hostKeyDB().Check, |
||||||
|
} |
||||||
|
return config |
||||||
|
} |
||||||
|
|
||||||
|
// unixConnection creates two halves of a connected net.UnixConn. It
|
||||||
|
// is used for connecting the Go SSH client with sshd without opening
|
||||||
|
// ports.
|
||||||
|
func unixConnection() (*net.UnixConn, *net.UnixConn, error) { |
||||||
|
dir, err := ioutil.TempDir("", "unixConnection") |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
defer os.Remove(dir) |
||||||
|
|
||||||
|
addr := filepath.Join(dir, "ssh") |
||||||
|
listener, err := net.Listen("unix", addr) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
defer listener.Close() |
||||||
|
c1, err := net.Dial("unix", addr) |
||||||
|
if err != nil { |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
c2, err := listener.Accept() |
||||||
|
if err != nil { |
||||||
|
c1.Close() |
||||||
|
return nil, nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return c1.(*net.UnixConn), c2.(*net.UnixConn), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { |
||||||
|
sshd, err := exec.LookPath("sshd") |
||||||
|
if err != nil { |
||||||
|
s.t.Skipf("skipping test: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
c1, c2, err := unixConnection() |
||||||
|
if err != nil { |
||||||
|
s.t.Fatalf("unixConnection: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") |
||||||
|
f, err := c2.File() |
||||||
|
if err != nil { |
||||||
|
s.t.Fatalf("UnixConn.File: %v", err) |
||||||
|
} |
||||||
|
defer f.Close() |
||||||
|
s.cmd.Stdin = f |
||||||
|
s.cmd.Stdout = f |
||||||
|
s.cmd.Stderr = &s.output |
||||||
|
if err := s.cmd.Start(); err != nil { |
||||||
|
s.t.Fail() |
||||||
|
s.Shutdown() |
||||||
|
s.t.Fatalf("s.cmd.Start: %v", err) |
||||||
|
} |
||||||
|
s.clientConn = c1 |
||||||
|
conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return ssh.NewClient(conn, chans, reqs), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { |
||||||
|
conn, err := s.TryDial(config) |
||||||
|
if err != nil { |
||||||
|
s.t.Fail() |
||||||
|
s.Shutdown() |
||||||
|
s.t.Fatalf("ssh.Client: %v", err) |
||||||
|
} |
||||||
|
return conn |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Shutdown() { |
||||||
|
if s.cmd != nil && s.cmd.Process != nil { |
||||||
|
// Don't check for errors; if it fails it's most
|
||||||
|
// likely "os: process already finished", and we don't
|
||||||
|
// care about that. Use os.Interrupt, so child
|
||||||
|
// processes are killed too.
|
||||||
|
s.cmd.Process.Signal(os.Interrupt) |
||||||
|
s.cmd.Wait() |
||||||
|
} |
||||||
|
if s.t.Failed() { |
||||||
|
// log any output from sshd process
|
||||||
|
s.t.Logf("sshd: %s", s.output.String()) |
||||||
|
} |
||||||
|
s.cleanup() |
||||||
|
} |
||||||
|
|
||||||
|
func writeFile(path string, contents []byte) { |
||||||
|
f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) |
||||||
|
if err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
defer f.Close() |
||||||
|
if _, err := f.Write(contents); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// newServer returns a new mock ssh server.
|
||||||
|
func newServer(t *testing.T) *server { |
||||||
|
if testing.Short() { |
||||||
|
t.Skip("skipping test due to -short") |
||||||
|
} |
||||||
|
dir, err := ioutil.TempDir("", "sshtest") |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
f, err := os.Create(filepath.Join(dir, "sshd_config")) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
err = configTmpl.Execute(f, map[string]string{ |
||||||
|
"Dir": dir, |
||||||
|
}) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
f.Close() |
||||||
|
|
||||||
|
for k, v := range testdata.PEMBytes { |
||||||
|
filename := "id_" + k |
||||||
|
writeFile(filepath.Join(dir, filename), v) |
||||||
|
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) |
||||||
|
} |
||||||
|
|
||||||
|
return &server{ |
||||||
|
t: t, |
||||||
|
configfile: f.Name(), |
||||||
|
cleanup: func() { |
||||||
|
if err := os.RemoveAll(dir); err != nil { |
||||||
|
t.Error(err) |
||||||
|
} |
||||||
|
}, |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,64 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
||||||
|
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
||||||
|
// instances.
|
||||||
|
|
||||||
|
package test |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh" |
||||||
|
"golang.org/x/crypto/ssh/testdata" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
testPrivateKeys map[string]interface{} |
||||||
|
testSigners map[string]ssh.Signer |
||||||
|
testPublicKeys map[string]ssh.PublicKey |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
var err error |
||||||
|
|
||||||
|
n := len(testdata.PEMBytes) |
||||||
|
testPrivateKeys = make(map[string]interface{}, n) |
||||||
|
testSigners = make(map[string]ssh.Signer, n) |
||||||
|
testPublicKeys = make(map[string]ssh.PublicKey, n) |
||||||
|
for t, k := range testdata.PEMBytes { |
||||||
|
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testPublicKeys[t] = testSigners[t].PublicKey() |
||||||
|
} |
||||||
|
|
||||||
|
// Create a cert and sign it for use in tests.
|
||||||
|
testCert := &ssh.Certificate{ |
||||||
|
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
||||||
|
ValidAfter: 0, // unix epoch
|
||||||
|
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
||||||
|
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
Key: testPublicKeys["ecdsa"], |
||||||
|
SignatureKey: testPublicKeys["rsa"], |
||||||
|
Permissions: ssh.Permissions{ |
||||||
|
CriticalOptions: map[string]string{}, |
||||||
|
Extensions: map[string]string{}, |
||||||
|
}, |
||||||
|
} |
||||||
|
testCert.SignCert(rand.Reader, testSigners["rsa"]) |
||||||
|
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] |
||||||
|
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,8 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// This package contains test data shared between the various subpackages of
|
||||||
|
// the golang.org/x/crypto/ssh package. Under no circumstance should
|
||||||
|
// this data be used for production code.
|
||||||
|
package testdata // import "golang.org/x/crypto/ssh/testdata"
|
@ -0,0 +1,43 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package testdata |
||||||
|
|
||||||
|
var PEMBytes = map[string][]byte{ |
||||||
|
"dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- |
||||||
|
MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB |
||||||
|
lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
|
||||||
|
EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD |
||||||
|
nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV |
||||||
|
2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r |
||||||
|
juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr |
||||||
|
FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz |
||||||
|
DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj |
||||||
|
nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY |
||||||
|
Fmsr0W6fHB9nhS4/UXM8 |
||||||
|
-----END DSA PRIVATE KEY----- |
||||||
|
`), |
||||||
|
"ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- |
||||||
|
MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 |
||||||
|
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ |
||||||
|
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== |
||||||
|
-----END EC PRIVATE KEY----- |
||||||
|
`), |
||||||
|
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- |
||||||
|
MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld |
||||||
|
r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ |
||||||
|
tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC |
||||||
|
nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW |
||||||
|
2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB |
||||||
|
y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr |
||||||
|
rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg== |
||||||
|
-----END RSA PRIVATE KEY----- |
||||||
|
`), |
||||||
|
"user": []byte(`-----BEGIN EC PRIVATE KEY----- |
||||||
|
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 |
||||||
|
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD |
||||||
|
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== |
||||||
|
-----END EC PRIVATE KEY----- |
||||||
|
`), |
||||||
|
} |
@ -0,0 +1,63 @@ |
|||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
||||||
|
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
||||||
|
// instances.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rand" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/gogits/gogs/modules/ssh/testdata" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
testPrivateKeys map[string]interface{} |
||||||
|
testSigners map[string]Signer |
||||||
|
testPublicKeys map[string]PublicKey |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
var err error |
||||||
|
|
||||||
|
n := len(testdata.PEMBytes) |
||||||
|
testPrivateKeys = make(map[string]interface{}, n) |
||||||
|
testSigners = make(map[string]Signer, n) |
||||||
|
testPublicKeys = make(map[string]PublicKey, n) |
||||||
|
for t, k := range testdata.PEMBytes { |
||||||
|
testPrivateKeys[t], err = ParseRawPrivateKey(k) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) |
||||||
|
} |
||||||
|
testPublicKeys[t] = testSigners[t].PublicKey() |
||||||
|
} |
||||||
|
|
||||||
|
// Create a cert and sign it for use in tests.
|
||||||
|
testCert := &Certificate{ |
||||||
|
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
||||||
|
ValidAfter: 0, // unix epoch
|
||||||
|
ValidBefore: CertTimeInfinity, // The end of currently representable time.
|
||||||
|
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||||
|
Key: testPublicKeys["ecdsa"], |
||||||
|
SignatureKey: testPublicKeys["rsa"], |
||||||
|
Permissions: Permissions{ |
||||||
|
CriticalOptions: map[string]string{}, |
||||||
|
Extensions: map[string]string{}, |
||||||
|
}, |
||||||
|
} |
||||||
|
testCert.SignCert(rand.Reader, testSigners["rsa"]) |
||||||
|
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] |
||||||
|
testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) |
||||||
|
if err != nil { |
||||||
|
panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,332 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
gcmCipherID = "aes128-gcm@openssh.com" |
||||||
|
aes128cbcID = "aes128-cbc" |
||||||
|
) |
||||||
|
|
||||||
|
// packetConn represents a transport that implements packet based
|
||||||
|
// operations.
|
||||||
|
type packetConn interface { |
||||||
|
// Encrypt and send a packet of data to the remote peer.
|
||||||
|
writePacket(packet []byte) error |
||||||
|
|
||||||
|
// Read a packet from the connection
|
||||||
|
readPacket() ([]byte, error) |
||||||
|
|
||||||
|
// Close closes the write-side of the connection.
|
||||||
|
Close() error |
||||||
|
} |
||||||
|
|
||||||
|
// transport is the keyingTransport that implements the SSH packet
|
||||||
|
// protocol.
|
||||||
|
type transport struct { |
||||||
|
reader connectionState |
||||||
|
writer connectionState |
||||||
|
|
||||||
|
bufReader *bufio.Reader |
||||||
|
bufWriter *bufio.Writer |
||||||
|
rand io.Reader |
||||||
|
|
||||||
|
io.Closer |
||||||
|
|
||||||
|
// Initial H used for the session ID. Once assigned this does
|
||||||
|
// not change, even during subsequent key exchanges.
|
||||||
|
sessionID []byte |
||||||
|
} |
||||||
|
|
||||||
|
// getSessionID returns the ID of the SSH connection. The return value
|
||||||
|
// should not be modified.
|
||||||
|
func (t *transport) getSessionID() []byte { |
||||||
|
if t.sessionID == nil { |
||||||
|
panic("session ID not set yet") |
||||||
|
} |
||||||
|
return t.sessionID |
||||||
|
} |
||||||
|
|
||||||
|
// packetCipher represents a combination of SSH encryption/MAC
|
||||||
|
// protocol. A single instance should be used for one direction only.
|
||||||
|
type packetCipher interface { |
||||||
|
// writePacket encrypts the packet and writes it to w. The
|
||||||
|
// contents of the packet are generally scrambled.
|
||||||
|
writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error |
||||||
|
|
||||||
|
// readPacket reads and decrypts a packet of data. The
|
||||||
|
// returned packet may be overwritten by future calls of
|
||||||
|
// readPacket.
|
||||||
|
readPacket(seqnum uint32, r io.Reader) ([]byte, error) |
||||||
|
} |
||||||
|
|
||||||
|
// connectionState represents one side (read or write) of the
|
||||||
|
// connection. This is necessary because each direction has its own
|
||||||
|
// keys, and can even have its own algorithms
|
||||||
|
type connectionState struct { |
||||||
|
packetCipher |
||||||
|
seqNum uint32 |
||||||
|
dir direction |
||||||
|
pendingKeyChange chan packetCipher |
||||||
|
} |
||||||
|
|
||||||
|
// prepareKeyChange sets up key material for a keychange. The key changes in
|
||||||
|
// both directions are triggered by reading and writing a msgNewKey packet
|
||||||
|
// respectively.
|
||||||
|
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { |
||||||
|
if t.sessionID == nil { |
||||||
|
t.sessionID = kexResult.H |
||||||
|
} |
||||||
|
|
||||||
|
kexResult.SessionID = t.sessionID |
||||||
|
|
||||||
|
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { |
||||||
|
return err |
||||||
|
} else { |
||||||
|
t.reader.pendingKeyChange <- ciph |
||||||
|
} |
||||||
|
|
||||||
|
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { |
||||||
|
return err |
||||||
|
} else { |
||||||
|
t.writer.pendingKeyChange <- ciph |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Read and decrypt next packet.
|
||||||
|
func (t *transport) readPacket() ([]byte, error) { |
||||||
|
return t.reader.readPacket(t.bufReader) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { |
||||||
|
packet, err := s.packetCipher.readPacket(s.seqNum, r) |
||||||
|
s.seqNum++ |
||||||
|
if err == nil && len(packet) == 0 { |
||||||
|
err = errors.New("ssh: zero length packet") |
||||||
|
} |
||||||
|
|
||||||
|
if len(packet) > 0 && packet[0] == msgNewKeys { |
||||||
|
select { |
||||||
|
case cipher := <-s.pendingKeyChange: |
||||||
|
s.packetCipher = cipher |
||||||
|
default: |
||||||
|
return nil, errors.New("ssh: got bogus newkeys message.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// The packet may point to an internal buffer, so copy the
|
||||||
|
// packet out here.
|
||||||
|
fresh := make([]byte, len(packet)) |
||||||
|
copy(fresh, packet) |
||||||
|
|
||||||
|
return fresh, err |
||||||
|
} |
||||||
|
|
||||||
|
func (t *transport) writePacket(packet []byte) error { |
||||||
|
return t.writer.writePacket(t.bufWriter, t.rand, packet) |
||||||
|
} |
||||||
|
|
||||||
|
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { |
||||||
|
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys |
||||||
|
|
||||||
|
err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if err = w.Flush(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
s.seqNum++ |
||||||
|
if changeKeys { |
||||||
|
select { |
||||||
|
case cipher := <-s.pendingKeyChange: |
||||||
|
s.packetCipher = cipher |
||||||
|
default: |
||||||
|
panic("ssh: no key material for msgNewKeys") |
||||||
|
} |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { |
||||||
|
t := &transport{ |
||||||
|
bufReader: bufio.NewReader(rwc), |
||||||
|
bufWriter: bufio.NewWriter(rwc), |
||||||
|
rand: rand, |
||||||
|
reader: connectionState{ |
||||||
|
packetCipher: &streamPacketCipher{cipher: noneCipher{}}, |
||||||
|
pendingKeyChange: make(chan packetCipher, 1), |
||||||
|
}, |
||||||
|
writer: connectionState{ |
||||||
|
packetCipher: &streamPacketCipher{cipher: noneCipher{}}, |
||||||
|
pendingKeyChange: make(chan packetCipher, 1), |
||||||
|
}, |
||||||
|
Closer: rwc, |
||||||
|
} |
||||||
|
if isClient { |
||||||
|
t.reader.dir = serverKeys |
||||||
|
t.writer.dir = clientKeys |
||||||
|
} else { |
||||||
|
t.reader.dir = clientKeys |
||||||
|
t.writer.dir = serverKeys |
||||||
|
} |
||||||
|
|
||||||
|
return t |
||||||
|
} |
||||||
|
|
||||||
|
type direction struct { |
||||||
|
ivTag []byte |
||||||
|
keyTag []byte |
||||||
|
macKeyTag []byte |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} |
||||||
|
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} |
||||||
|
) |
||||||
|
|
||||||
|
// generateKeys generates key material for IV, MAC and encryption.
|
||||||
|
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { |
||||||
|
cipherMode := cipherModes[algs.Cipher] |
||||||
|
macMode := macModes[algs.MAC] |
||||||
|
|
||||||
|
iv = make([]byte, cipherMode.ivSize) |
||||||
|
key = make([]byte, cipherMode.keySize) |
||||||
|
macKey = make([]byte, macMode.keySize) |
||||||
|
|
||||||
|
generateKeyMaterial(iv, d.ivTag, kex) |
||||||
|
generateKeyMaterial(key, d.keyTag, kex) |
||||||
|
generateKeyMaterial(macKey, d.macKeyTag, kex) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
||||||
|
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
||||||
|
// (to setup server->client keys) or clientKeys (for client->server keys).
|
||||||
|
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { |
||||||
|
iv, key, macKey := generateKeys(d, algs, kex) |
||||||
|
|
||||||
|
if algs.Cipher == gcmCipherID { |
||||||
|
return newGCMCipher(iv, key, macKey) |
||||||
|
} |
||||||
|
|
||||||
|
if algs.Cipher == aes128cbcID { |
||||||
|
return newAESCBCCipher(iv, key, macKey, algs) |
||||||
|
} |
||||||
|
|
||||||
|
c := &streamPacketCipher{ |
||||||
|
mac: macModes[algs.MAC].new(macKey), |
||||||
|
} |
||||||
|
c.macResult = make([]byte, c.mac.Size()) |
||||||
|
|
||||||
|
var err error |
||||||
|
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
// generateKeyMaterial fills out with key material generated from tag, K, H
|
||||||
|
// and sessionId, as specified in RFC 4253, section 7.2.
|
||||||
|
func generateKeyMaterial(out, tag []byte, r *kexResult) { |
||||||
|
var digestsSoFar []byte |
||||||
|
|
||||||
|
h := r.Hash.New() |
||||||
|
for len(out) > 0 { |
||||||
|
h.Reset() |
||||||
|
h.Write(r.K) |
||||||
|
h.Write(r.H) |
||||||
|
|
||||||
|
if len(digestsSoFar) == 0 { |
||||||
|
h.Write(tag) |
||||||
|
h.Write(r.SessionID) |
||||||
|
} else { |
||||||
|
h.Write(digestsSoFar) |
||||||
|
} |
||||||
|
|
||||||
|
digest := h.Sum(nil) |
||||||
|
n := copy(out, digest) |
||||||
|
out = out[n:] |
||||||
|
if len(out) > 0 { |
||||||
|
digestsSoFar = append(digestsSoFar, digest...) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
const packageVersion = "SSH-2.0-Go" |
||||||
|
|
||||||
|
// Sends and receives a version line. The versionLine string should
|
||||||
|
// be US ASCII, start with "SSH-2.0-", and should not include a
|
||||||
|
// newline. exchangeVersions returns the other side's version line.
|
||||||
|
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { |
||||||
|
// Contrary to the RFC, we do not ignore lines that don't
|
||||||
|
// start with "SSH-2.0-" to make the library usable with
|
||||||
|
// nonconforming servers.
|
||||||
|
for _, c := range versionLine { |
||||||
|
// The spec disallows non US-ASCII chars, and
|
||||||
|
// specifically forbids null chars.
|
||||||
|
if c < 32 { |
||||||
|
return nil, errors.New("ssh: junk character in version line") |
||||||
|
} |
||||||
|
} |
||||||
|
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
them, err = readVersion(rw) |
||||||
|
return them, err |
||||||
|
} |
||||||
|
|
||||||
|
// maxVersionStringBytes is the maximum number of bytes that we'll
|
||||||
|
// accept as a version string. RFC 4253 section 4.2 limits this at 255
|
||||||
|
// chars
|
||||||
|
const maxVersionStringBytes = 255 |
||||||
|
|
||||||
|
// Read version string as specified by RFC 4253, section 4.2.
|
||||||
|
func readVersion(r io.Reader) ([]byte, error) { |
||||||
|
versionString := make([]byte, 0, 64) |
||||||
|
var ok bool |
||||||
|
var buf [1]byte |
||||||
|
|
||||||
|
for len(versionString) < maxVersionStringBytes { |
||||||
|
_, err := io.ReadFull(r, buf[:]) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
// The RFC says that the version should be terminated with \r\n
|
||||||
|
// but several SSH servers actually only send a \n.
|
||||||
|
if buf[0] == '\n' { |
||||||
|
ok = true |
||||||
|
break |
||||||
|
} |
||||||
|
|
||||||
|
// non ASCII chars are disallowed, but we are lenient,
|
||||||
|
// since Go doesn't use null-terminated strings.
|
||||||
|
|
||||||
|
// The RFC allows a comment after a space, however,
|
||||||
|
// all of it (version and comments) goes into the
|
||||||
|
// session hash.
|
||||||
|
versionString = append(versionString, buf[0]) |
||||||
|
} |
||||||
|
|
||||||
|
if !ok { |
||||||
|
return nil, errors.New("ssh: overflow reading version string") |
||||||
|
} |
||||||
|
|
||||||
|
// There might be a '\r' on the end which we should remove.
|
||||||
|
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { |
||||||
|
versionString = versionString[:len(versionString)-1] |
||||||
|
} |
||||||
|
return versionString, nil |
||||||
|
} |
@ -0,0 +1,109 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"crypto/rand" |
||||||
|
"encoding/binary" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
) |
||||||
|
|
||||||
|
func TestReadVersion(t *testing.T) { |
||||||
|
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] |
||||||
|
cases := map[string]string{ |
||||||
|
"SSH-2.0-bla\r\n": "SSH-2.0-bla", |
||||||
|
"SSH-2.0-bla\n": "SSH-2.0-bla", |
||||||
|
longversion + "\r\n": longversion, |
||||||
|
} |
||||||
|
|
||||||
|
for in, want := range cases { |
||||||
|
result, err := readVersion(bytes.NewBufferString(in)) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("readVersion(%q): %s", in, err) |
||||||
|
} |
||||||
|
got := string(result) |
||||||
|
if got != want { |
||||||
|
t.Errorf("got %q, want %q", got, want) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestReadVersionError(t *testing.T) { |
||||||
|
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] |
||||||
|
cases := []string{ |
||||||
|
longversion + "too-long\r\n", |
||||||
|
} |
||||||
|
for _, in := range cases { |
||||||
|
if _, err := readVersion(bytes.NewBufferString(in)); err == nil { |
||||||
|
t.Errorf("readVersion(%q) should have failed", in) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestExchangeVersionsBasic(t *testing.T) { |
||||||
|
v := "SSH-2.0-bla" |
||||||
|
buf := bytes.NewBufferString(v + "\r\n") |
||||||
|
them, err := exchangeVersions(buf, []byte("xyz")) |
||||||
|
if err != nil { |
||||||
|
t.Errorf("exchangeVersions: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
if want := "SSH-2.0-bla"; string(them) != want { |
||||||
|
t.Errorf("got %q want %q for our version", them, want) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestExchangeVersions(t *testing.T) { |
||||||
|
cases := []string{ |
||||||
|
"not\x000allowed", |
||||||
|
"not allowed\n", |
||||||
|
} |
||||||
|
for _, c := range cases { |
||||||
|
buf := bytes.NewBufferString("SSH-2.0-bla\r\n") |
||||||
|
if _, err := exchangeVersions(buf, []byte(c)); err == nil { |
||||||
|
t.Errorf("exchangeVersions(%q): should have failed", c) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type closerBuffer struct { |
||||||
|
bytes.Buffer |
||||||
|
} |
||||||
|
|
||||||
|
func (b *closerBuffer) Close() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func TestTransportMaxPacketWrite(t *testing.T) { |
||||||
|
buf := &closerBuffer{} |
||||||
|
tr := newTransport(buf, rand.Reader, true) |
||||||
|
huge := make([]byte, maxPacket+1) |
||||||
|
err := tr.writePacket(huge) |
||||||
|
if err == nil { |
||||||
|
t.Errorf("transport accepted write for a huge packet.") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestTransportMaxPacketReader(t *testing.T) { |
||||||
|
var header [5]byte |
||||||
|
huge := make([]byte, maxPacket+128) |
||||||
|
binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) |
||||||
|
// padding.
|
||||||
|
header[4] = 0 |
||||||
|
|
||||||
|
buf := &closerBuffer{} |
||||||
|
buf.Write(header[:]) |
||||||
|
buf.Write(huge) |
||||||
|
|
||||||
|
tr := newTransport(buf, rand.Reader, true) |
||||||
|
_, err := tr.readPacket() |
||||||
|
if err == nil { |
||||||
|
t.Errorf("transport succeeded reading huge packet.") |
||||||
|
} else if !strings.Contains(err.Error(), "large") { |
||||||
|
t.Errorf("got %q, should mention %q", err.Error(), "large") |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue