Merge branch 'develop' of https://github.com/gogits/gogs into developtokarchuk/v1.17
commit
737da1a374
@ -0,0 +1,35 @@ |
||||
LDFLAGS += -X "github.com/gogits/gogs/modules/setting.BuildTime=$(shell date -u '+%Y-%m-%d %I:%M:%S %Z')"
|
||||
LDFLAGS += -X "github.com/gogits/gogs/modules/setting.BuildGitHash=$(shell git rev-parse HEAD)"
|
||||
|
||||
TAGS = ""
|
||||
|
||||
RELEASE_ROOT = "release"
|
||||
RELEASE_GOGS = "release/gogs"
|
||||
NOW = $(shell date -u '+%Y%m%d%I%M%S')
|
||||
|
||||
.PHONY: build pack release bindata clean |
||||
|
||||
build: |
||||
go install -ldflags '$(LDFLAGS)' -tags '$(TAGS)'
|
||||
go build -ldflags '$(LDFLAGS)' -tags '$(TAGS)'
|
||||
|
||||
govet: |
||||
go tool vet -composites=false -methods=false -structtags=false .
|
||||
|
||||
pack: |
||||
rm -rf $(RELEASE_GOGS)
|
||||
mkdir -p $(RELEASE_GOGS)
|
||||
cp -r gogs LICENSE README.md README_ZH.md templates public scripts $(RELEASE_GOGS)
|
||||
rm -rf $(RELEASE_GOGS)/public/config.codekit $(RELEASE_GOGS)/public/less
|
||||
cd $(RELEASE_ROOT) && zip -r gogs.$(NOW).zip "gogs"
|
||||
|
||||
release: build pack |
||||
|
||||
bindata: |
||||
go-bindata -o=modules/bindata/bindata.go -ignore="\\.DS_Store|README.md" -pkg=bindata conf/...
|
||||
|
||||
clean: |
||||
go clean -i ./...
|
||||
|
||||
clean-mac: clean |
||||
find . -name ".DS_Store" -print0 | xargs -0 rm
|
@ -0,0 +1,42 @@ |
||||
// Copyright 2015 The Gogs Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/codegangsta/cli" |
||||
) |
||||
|
||||
func stringFlag(name, value, usage string) cli.StringFlag { |
||||
return cli.StringFlag{ |
||||
Name: name, |
||||
Value: value, |
||||
Usage: usage, |
||||
} |
||||
} |
||||
|
||||
func boolFlag(name, usage string) cli.BoolFlag { |
||||
return cli.BoolFlag{ |
||||
Name: name, |
||||
Usage: usage, |
||||
} |
||||
} |
||||
|
||||
func intFlag(name string, value int, usage string) cli.IntFlag { |
||||
return cli.IntFlag{ |
||||
Name: name, |
||||
Value: value, |
||||
Usage: usage, |
||||
} |
||||
} |
||||
|
||||
func durationFlag(name string, value time.Duration, usage string) cli.DurationFlag { |
||||
return cli.DurationFlag{ |
||||
Name: name, |
||||
Value: value, |
||||
Usage: usage, |
||||
} |
||||
} |
@ -1,20 +1,29 @@ |
||||
# This file lists all PUBLIC individuals having contributed content to the translation. |
||||
# Order of name is meaningless. |
||||
# Entries are in alphabetical order. |
||||
|
||||
Akihiro YAGASAKI <yaggytter@momiage.com> |
||||
Alexander Steinhöfer <kontakt@lx-s.de> |
||||
Alexandre Magno <alexandre.mbm@gmail.com> |
||||
Barış Arda Yılmaz <ardayilmazgamer@gmail.com> |
||||
Christoph Kisfeld <christoph.kisfeld@gmail.com> |
||||
Daniel Speichert <daniel@speichert.pl> |
||||
Gregor Santner <gdev@live.de> |
||||
Huimin Wang <wanghm2009@hotmail.co.jp> |
||||
ilko <email> |
||||
Thomas Fanninger <gogs.thomas@fanninger.at> |
||||
Łukasz Jan Niemier <lukasz@niemier.pl> |
||||
Lafriks <lafriks@gmail.com> |
||||
Luc Stepniewski <luc@stepniewski.fr> |
||||
Miguel de la Cruz <miguel@mcrx.me> |
||||
Marc Schiller <marc@schiller.im> |
||||
Morten Sørensen <klim8d@gmail.com> |
||||
Natan Albuquerque <natanalbuquerque5@gmail.com> |
||||
Akihiro YAGASAKI <yaggytter AT momiage DOT com> |
||||
Alexander Steinhöfer <kontakt AT lx-s DOT de> |
||||
Alexandre Magno <alexandre DOT mbm AT gmail DOT com> |
||||
Andrey Nering <andrey AT nering DOT com DOT br> |
||||
Arthur Aslanyan <arthur DOT e DOT aslanyan AT gmail DOT com> |
||||
Barış Arda Yılmaz <ardayilmazgamer AT gmail DOT com> |
||||
Christoph Kisfeld <christoph DOT kisfeld AT gmail DOT com> |
||||
Daniel Speichert <daniel AT speichert DOT pl> |
||||
Dmitriy Nogay <me AT catwhocode DOT ga> |
||||
Gregor Santner <gdev AT live DOT de> |
||||
Hamid Feizabadi <hamidfzm AT gmail DOT com> |
||||
Huimin Wang <wanghm2009 AT hotmail DOT co DOT jp> |
||||
ilko |
||||
Lafriks <lafriks AT gmail DOT com> |
||||
Lauri Ojansivu <x AT xet7 DOT org> |
||||
Luc Stepniewski <luc AT stepniewski DOT fr> |
||||
Marc Schiller <marc AT schiller DOT im> |
||||
Miguel de la Cruz <miguel AT mcrx DOT me> |
||||
Morten Sørensen <klim8d AT gmail DOT com> |
||||
Natan Albuquerque <natanalbuquerque5 AT gmail DOT com> |
||||
Odilon Junior <odilon DOT junior93 AT gmail DOT com> |
||||
Thomas Fanninger <gogs DOT thomas AT fanninger DOT at> |
||||
Tilmann Bach <tilmann AT outlook DOT com> |
||||
Vladimir Vissoultchev <wqweto AT gmail DOT com> |
||||
YJSoft <yjsoft AT yjsoft DOT pe DOT kr> |
||||
Łukasz Jan Niemier <lukasz AT niemier DOT pl> |
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,12 +0,0 @@ |
||||
web: |
||||
build: . |
||||
links: |
||||
- mysql |
||||
ports: |
||||
- "3000:3000" |
||||
|
||||
mysql: |
||||
image: mysql |
||||
environment: |
||||
- MYSQL_ROOT_PASSWORD=gogs |
||||
- MYSQL_DATABASE=gogs |
@ -0,0 +1,7 @@ |
||||
#!/bin/sh |
||||
|
||||
if test -f ./setup; then |
||||
source ./setup |
||||
fi |
||||
|
||||
exec gosu root /sbin/syslogd -nS -O- |
File diff suppressed because one or more lines are too long
@ -1,615 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* |
||||
Package agent implements a client to an ssh-agent daemon. |
||||
|
||||
References: |
||||
[PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD
|
||||
*/ |
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rsa" |
||||
"encoding/base64" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// Agent represents the capabilities of an ssh-agent.
|
||||
type Agent interface { |
||||
// List returns the identities known to the agent.
|
||||
List() ([]*Key, error) |
||||
|
||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||
// in [PROTOCOL.agent] section 2.6.2.
|
||||
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) |
||||
|
||||
// Add adds a private key to the agent.
|
||||
Add(key AddedKey) error |
||||
|
||||
// Remove removes all identities with the given public key.
|
||||
Remove(key ssh.PublicKey) error |
||||
|
||||
// RemoveAll removes all identities.
|
||||
RemoveAll() error |
||||
|
||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||
Lock(passphrase []byte) error |
||||
|
||||
// Unlock undoes the effect of Lock
|
||||
Unlock(passphrase []byte) error |
||||
|
||||
// Signers returns signers for all the known keys.
|
||||
Signers() ([]ssh.Signer, error) |
||||
} |
||||
|
||||
// AddedKey describes an SSH key to be added to an Agent.
|
||||
type AddedKey struct { |
||||
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
|
||||
// *ecdsa.PrivateKey, which will be inserted into the agent.
|
||||
PrivateKey interface{} |
||||
// Certificate, if not nil, is communicated to the agent and will be
|
||||
// stored with the key.
|
||||
Certificate *ssh.Certificate |
||||
// Comment is an optional, free-form string.
|
||||
Comment string |
||||
// LifetimeSecs, if not zero, is the number of seconds that the
|
||||
// agent will store the key for.
|
||||
LifetimeSecs uint32 |
||||
// ConfirmBeforeUse, if true, requests that the agent confirm with the
|
||||
// user before each use of this key.
|
||||
ConfirmBeforeUse bool |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 3.
|
||||
const ( |
||||
agentRequestV1Identities = 1 |
||||
|
||||
// 3.2 Requests from client to agent for protocol 2 key operations
|
||||
agentAddIdentity = 17 |
||||
agentRemoveIdentity = 18 |
||||
agentRemoveAllIdentities = 19 |
||||
agentAddIdConstrained = 25 |
||||
|
||||
// 3.3 Key-type independent requests from client to agent
|
||||
agentAddSmartcardKey = 20 |
||||
agentRemoveSmartcardKey = 21 |
||||
agentLock = 22 |
||||
agentUnlock = 23 |
||||
agentAddSmartcardKeyConstrained = 26 |
||||
|
||||
// 3.7 Key constraint identifiers
|
||||
agentConstrainLifetime = 1 |
||||
agentConstrainConfirm = 2 |
||||
) |
||||
|
||||
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
||||
// is a sanity check, not a limit in the spec.
|
||||
const maxAgentResponseBytes = 16 << 20 |
||||
|
||||
// Agent messages:
|
||||
// These structures mirror the wire format of the corresponding ssh agent
|
||||
// messages found in [PROTOCOL.agent].
|
||||
|
||||
// 3.4 Generic replies from agent to client
|
||||
const agentFailure = 5 |
||||
|
||||
type failureAgentMsg struct{} |
||||
|
||||
const agentSuccess = 6 |
||||
|
||||
type successAgentMsg struct{} |
||||
|
||||
// See [PROTOCOL.agent], section 2.5.2.
|
||||
const agentRequestIdentities = 11 |
||||
|
||||
type requestIdentitiesAgentMsg struct{} |
||||
|
||||
// See [PROTOCOL.agent], section 2.5.2.
|
||||
const agentIdentitiesAnswer = 12 |
||||
|
||||
type identitiesAnswerAgentMsg struct { |
||||
NumKeys uint32 `sshtype:"12"` |
||||
Keys []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 2.6.2.
|
||||
const agentSignRequest = 13 |
||||
|
||||
type signRequestAgentMsg struct { |
||||
KeyBlob []byte `sshtype:"13"` |
||||
Data []byte |
||||
Flags uint32 |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 2.6.2.
|
||||
|
||||
// 3.6 Replies from agent to client for protocol 2 key operations
|
||||
const agentSignResponse = 14 |
||||
|
||||
type signResponseAgentMsg struct { |
||||
SigBlob []byte `sshtype:"14"` |
||||
} |
||||
|
||||
type publicKey struct { |
||||
Format string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Key represents a protocol 2 public key as defined in
|
||||
// [PROTOCOL.agent], section 2.5.2.
|
||||
type Key struct { |
||||
Format string |
||||
Blob []byte |
||||
Comment string |
||||
} |
||||
|
||||
func clientErr(err error) error { |
||||
return fmt.Errorf("agent: client error: %v", err) |
||||
} |
||||
|
||||
// String returns the storage form of an agent key with the format, base64
|
||||
// encoded serialized key, and the comment if it is not empty.
|
||||
func (k *Key) String() string { |
||||
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) |
||||
|
||||
if k.Comment != "" { |
||||
s += " " + k.Comment |
||||
} |
||||
|
||||
return s |
||||
} |
||||
|
||||
// Type returns the public key type.
|
||||
func (k *Key) Type() string { |
||||
return k.Format |
||||
} |
||||
|
||||
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
|
||||
func (k *Key) Marshal() []byte { |
||||
return k.Blob |
||||
} |
||||
|
||||
// Verify satisfies the ssh.PublicKey interface, but is not
|
||||
// implemented for agent keys.
|
||||
func (k *Key) Verify(data []byte, sig *ssh.Signature) error { |
||||
return errors.New("agent: agent key does not know how to verify") |
||||
} |
||||
|
||||
type wireKey struct { |
||||
Format string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
func parseKey(in []byte) (out *Key, rest []byte, err error) { |
||||
var record struct { |
||||
Blob []byte |
||||
Comment string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := ssh.Unmarshal(in, &record); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(record.Blob, &wk); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return &Key{ |
||||
Format: wk.Format, |
||||
Blob: record.Blob, |
||||
Comment: record.Comment, |
||||
}, record.Rest, nil |
||||
} |
||||
|
||||
// client is a client for an ssh-agent process.
|
||||
type client struct { |
||||
// conn is typically a *net.UnixConn
|
||||
conn io.ReadWriter |
||||
// mu is used to prevent concurrent access to the agent
|
||||
mu sync.Mutex |
||||
} |
||||
|
||||
// NewClient returns an Agent that talks to an ssh-agent process over
|
||||
// the given connection.
|
||||
func NewClient(rw io.ReadWriter) Agent { |
||||
return &client{conn: rw} |
||||
} |
||||
|
||||
// call sends an RPC to the agent. On success, the reply is
|
||||
// unmarshaled into reply and replyType is set to the first byte of
|
||||
// the reply, which contains the type of the message.
|
||||
func (c *client) call(req []byte) (reply interface{}, err error) { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
|
||||
msg := make([]byte, 4+len(req)) |
||||
binary.BigEndian.PutUint32(msg, uint32(len(req))) |
||||
copy(msg[4:], req) |
||||
if _, err = c.conn.Write(msg); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
|
||||
var respSizeBuf [4]byte |
||||
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
respSize := binary.BigEndian.Uint32(respSizeBuf[:]) |
||||
if respSize > maxAgentResponseBytes { |
||||
return nil, clientErr(err) |
||||
} |
||||
|
||||
buf := make([]byte, respSize) |
||||
if _, err = io.ReadFull(c.conn, buf); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
reply, err = unmarshal(buf) |
||||
if err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
return reply, err |
||||
} |
||||
|
||||
func (c *client) simpleCall(req []byte) error { |
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
func (c *client) RemoveAll() error { |
||||
return c.simpleCall([]byte{agentRemoveAllIdentities}) |
||||
} |
||||
|
||||
func (c *client) Remove(key ssh.PublicKey) error { |
||||
req := ssh.Marshal(&agentRemoveIdentityMsg{ |
||||
KeyBlob: key.Marshal(), |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
func (c *client) Lock(passphrase []byte) error { |
||||
req := ssh.Marshal(&agentLockMsg{ |
||||
Passphrase: passphrase, |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
func (c *client) Unlock(passphrase []byte) error { |
||||
req := ssh.Marshal(&agentUnlockMsg{ |
||||
Passphrase: passphrase, |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
// List returns the identities known to the agent.
|
||||
func (c *client) List() ([]*Key, error) { |
||||
// see [PROTOCOL.agent] section 2.5.2.
|
||||
req := []byte{agentRequestIdentities} |
||||
|
||||
msg, err := c.call(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *identitiesAnswerAgentMsg: |
||||
if msg.NumKeys > maxAgentResponseBytes/8 { |
||||
return nil, errors.New("agent: too many keys in agent reply") |
||||
} |
||||
keys := make([]*Key, msg.NumKeys) |
||||
data := msg.Keys |
||||
for i := uint32(0); i < msg.NumKeys; i++ { |
||||
var key *Key |
||||
var err error |
||||
if key, data, err = parseKey(data); err != nil { |
||||
return nil, err |
||||
} |
||||
keys[i] = key |
||||
} |
||||
return keys, nil |
||||
case *failureAgentMsg: |
||||
return nil, errors.New("agent: failed to list keys") |
||||
} |
||||
panic("unreachable") |
||||
} |
||||
|
||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||
// in [PROTOCOL.agent] section 2.6.2.
|
||||
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||
req := ssh.Marshal(signRequestAgentMsg{ |
||||
KeyBlob: key.Marshal(), |
||||
Data: data, |
||||
}) |
||||
|
||||
msg, err := c.call(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *signResponseAgentMsg: |
||||
var sig ssh.Signature |
||||
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &sig, nil |
||||
case *failureAgentMsg: |
||||
return nil, errors.New("agent: failed to sign challenge") |
||||
} |
||||
panic("unreachable") |
||||
} |
||||
|
||||
// unmarshal parses an agent message in packet, returning the parsed
|
||||
// form and the message type of packet.
|
||||
func unmarshal(packet []byte) (interface{}, error) { |
||||
if len(packet) < 1 { |
||||
return nil, errors.New("agent: empty packet") |
||||
} |
||||
var msg interface{} |
||||
switch packet[0] { |
||||
case agentFailure: |
||||
return new(failureAgentMsg), nil |
||||
case agentSuccess: |
||||
return new(successAgentMsg), nil |
||||
case agentIdentitiesAnswer: |
||||
msg = new(identitiesAnswerAgentMsg) |
||||
case agentSignResponse: |
||||
msg = new(signResponseAgentMsg) |
||||
default: |
||||
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) |
||||
} |
||||
if err := ssh.Unmarshal(packet, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return msg, nil |
||||
} |
||||
|
||||
type rsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
N *big.Int |
||||
E *big.Int |
||||
D *big.Int |
||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||
P *big.Int |
||||
Q *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type dsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
P *big.Int |
||||
Q *big.Int |
||||
G *big.Int |
||||
Y *big.Int |
||||
X *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type ecdsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
Curve string |
||||
KeyBytes []byte |
||||
D *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Insert adds a private key to the agent.
|
||||
func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { |
||||
var req []byte |
||||
switch k := s.(type) { |
||||
case *rsa.PrivateKey: |
||||
if len(k.Primes) != 2 { |
||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||
} |
||||
k.Precompute() |
||||
req = ssh.Marshal(rsaKeyMsg{ |
||||
Type: ssh.KeyAlgoRSA, |
||||
N: k.N, |
||||
E: big.NewInt(int64(k.E)), |
||||
D: k.D, |
||||
Iqmp: k.Precomputed.Qinv, |
||||
P: k.Primes[0], |
||||
Q: k.Primes[1], |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *dsa.PrivateKey: |
||||
req = ssh.Marshal(dsaKeyMsg{ |
||||
Type: ssh.KeyAlgoDSA, |
||||
P: k.P, |
||||
Q: k.Q, |
||||
G: k.G, |
||||
Y: k.Y, |
||||
X: k.X, |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *ecdsa.PrivateKey: |
||||
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) |
||||
req = ssh.Marshal(ecdsaKeyMsg{ |
||||
Type: "ecdsa-sha2-" + nistID, |
||||
Curve: nistID, |
||||
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), |
||||
D: k.D, |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
default: |
||||
return fmt.Errorf("agent: unsupported key type %T", s) |
||||
} |
||||
|
||||
// if constraints are present then the message type needs to be changed.
|
||||
if len(constraints) != 0 { |
||||
req[0] = agentAddIdConstrained |
||||
} |
||||
|
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
type rsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
D *big.Int |
||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||
P *big.Int |
||||
Q *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type dsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
X *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type ecdsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
D *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Insert adds a private key to the agent. If a certificate is given,
|
||||
// that certificate is added instead as public key.
|
||||
func (c *client) Add(key AddedKey) error { |
||||
var constraints []byte |
||||
|
||||
if secs := key.LifetimeSecs; secs != 0 { |
||||
constraints = append(constraints, agentConstrainLifetime) |
||||
|
||||
var secsBytes [4]byte |
||||
binary.BigEndian.PutUint32(secsBytes[:], secs) |
||||
constraints = append(constraints, secsBytes[:]...) |
||||
} |
||||
|
||||
if key.ConfirmBeforeUse { |
||||
constraints = append(constraints, agentConstrainConfirm) |
||||
} |
||||
|
||||
if cert := key.Certificate; cert == nil { |
||||
return c.insertKey(key.PrivateKey, key.Comment, constraints) |
||||
} else { |
||||
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) |
||||
} |
||||
} |
||||
|
||||
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { |
||||
var req []byte |
||||
switch k := s.(type) { |
||||
case *rsa.PrivateKey: |
||||
if len(k.Primes) != 2 { |
||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||
} |
||||
k.Precompute() |
||||
req = ssh.Marshal(rsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
D: k.D, |
||||
Iqmp: k.Precomputed.Qinv, |
||||
P: k.Primes[0], |
||||
Q: k.Primes[1], |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *dsa.PrivateKey: |
||||
req = ssh.Marshal(dsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
X: k.X, |
||||
Comments: comment, |
||||
}) |
||||
case *ecdsa.PrivateKey: |
||||
req = ssh.Marshal(ecdsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
D: k.D, |
||||
Comments: comment, |
||||
}) |
||||
default: |
||||
return fmt.Errorf("agent: unsupported key type %T", s) |
||||
} |
||||
|
||||
// if constraints are present then the message type needs to be changed.
|
||||
if len(constraints) != 0 { |
||||
req[0] = agentAddIdConstrained |
||||
} |
||||
|
||||
signer, err := ssh.NewSignerFromKey(s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||
return errors.New("agent: signer and cert have different public key") |
||||
} |
||||
|
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
// Signers provides a callback for client authentication.
|
||||
func (c *client) Signers() ([]ssh.Signer, error) { |
||||
keys, err := c.List() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var result []ssh.Signer |
||||
for _, k := range keys { |
||||
result = append(result, &agentKeyringSigner{c, k}) |
||||
} |
||||
return result, nil |
||||
} |
||||
|
||||
type agentKeyringSigner struct { |
||||
agent *client |
||||
pub ssh.PublicKey |
||||
} |
||||
|
||||
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { |
||||
return s.pub |
||||
} |
||||
|
||||
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { |
||||
// The agent has its own entropy source, so the rand argument is ignored.
|
||||
return s.agent.Sign(s.pub, data) |
||||
} |
@ -1,287 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"net" |
||||
"os" |
||||
"os/exec" |
||||
"path/filepath" |
||||
"strconv" |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
||||
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { |
||||
if testing.Short() { |
||||
// ssh-agent is not always available, and the key
|
||||
// types supported vary by platform.
|
||||
t.Skip("skipping test due to -short") |
||||
} |
||||
|
||||
bin, err := exec.LookPath("ssh-agent") |
||||
if err != nil { |
||||
t.Skip("could not find ssh-agent") |
||||
} |
||||
|
||||
cmd := exec.Command(bin, "-s") |
||||
out, err := cmd.Output() |
||||
if err != nil { |
||||
t.Fatalf("cmd.Output: %v", err) |
||||
} |
||||
|
||||
/* Output looks like: |
||||
|
||||
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; |
||||
SSH_AGENT_PID=15542; export SSH_AGENT_PID; |
||||
echo Agent pid 15542; |
||||
*/ |
||||
fields := bytes.Split(out, []byte(";")) |
||||
line := bytes.SplitN(fields[0], []byte("="), 2) |
||||
line[0] = bytes.TrimLeft(line[0], "\n") |
||||
if string(line[0]) != "SSH_AUTH_SOCK" { |
||||
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) |
||||
} |
||||
socket = string(line[1]) |
||||
|
||||
line = bytes.SplitN(fields[2], []byte("="), 2) |
||||
line[0] = bytes.TrimLeft(line[0], "\n") |
||||
if string(line[0]) != "SSH_AGENT_PID" { |
||||
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) |
||||
} |
||||
pidStr := line[1] |
||||
pid, err := strconv.Atoi(string(pidStr)) |
||||
if err != nil { |
||||
t.Fatalf("Atoi(%q): %v", pidStr, err) |
||||
} |
||||
|
||||
conn, err := net.Dial("unix", string(socket)) |
||||
if err != nil { |
||||
t.Fatalf("net.Dial: %v", err) |
||||
} |
||||
|
||||
ac := NewClient(conn) |
||||
return ac, socket, func() { |
||||
proc, _ := os.FindProcess(pid) |
||||
if proc != nil { |
||||
proc.Kill() |
||||
} |
||||
conn.Close() |
||||
os.RemoveAll(filepath.Dir(socket)) |
||||
} |
||||
} |
||||
|
||||
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
testAgentInterface(t, agent, key, cert, lifetimeSecs) |
||||
} |
||||
|
||||
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||
signer, err := ssh.NewSignerFromKey(key) |
||||
if err != nil { |
||||
t.Fatalf("NewSignerFromKey(%T): %v", key, err) |
||||
} |
||||
// The agent should start up empty.
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Fatalf("RequestIdentities: %v", err) |
||||
} else if len(keys) > 0 { |
||||
t.Fatalf("got %d keys, want 0: %v", len(keys), keys) |
||||
} |
||||
|
||||
// Attempt to insert the key, with certificate if specified.
|
||||
var pubKey ssh.PublicKey |
||||
if cert != nil { |
||||
err = agent.Add(AddedKey{ |
||||
PrivateKey: key, |
||||
Certificate: cert, |
||||
Comment: "comment", |
||||
LifetimeSecs: lifetimeSecs, |
||||
}) |
||||
pubKey = cert |
||||
} else { |
||||
err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) |
||||
pubKey = signer.PublicKey() |
||||
} |
||||
if err != nil { |
||||
t.Fatalf("insert(%T): %v", key, err) |
||||
} |
||||
|
||||
// Did the key get inserted successfully?
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Fatalf("List: %v", err) |
||||
} else if len(keys) != 1 { |
||||
t.Fatalf("got %v, want 1 key", keys) |
||||
} else if keys[0].Comment != "comment" { |
||||
t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") |
||||
} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { |
||||
t.Fatalf("key mismatch") |
||||
} |
||||
|
||||
// Can the agent make a valid signature?
|
||||
data := []byte("hello") |
||||
sig, err := agent.Sign(pubKey, data) |
||||
if err != nil { |
||||
t.Fatalf("Sign(%s): %v", pubKey.Type(), err) |
||||
} |
||||
|
||||
if err := pubKey.Verify(data, sig); err != nil { |
||||
t.Fatalf("Verify(%s): %v", pubKey.Type(), err) |
||||
} |
||||
} |
||||
|
||||
func TestAgent(t *testing.T) { |
||||
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { |
||||
testAgent(t, testPrivateKeys[keyType], nil, 0) |
||||
} |
||||
} |
||||
|
||||
func TestCert(t *testing.T) { |
||||
cert := &ssh.Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: ssh.CertTimeInfinity, |
||||
CertType: ssh.UserCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
testAgent(t, testPrivateKeys["rsa"], cert, 0) |
||||
} |
||||
|
||||
func TestConstraints(t *testing.T) { |
||||
testAgent(t, testPrivateKeys["rsa"], nil, 3600 /* lifetime in seconds */) |
||||
} |
||||
|
||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||
// a write.)
|
||||
func netPipe() (net.Conn, net.Conn, error) { |
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
defer listener.Close() |
||||
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
c2, err := listener.Accept() |
||||
if err != nil { |
||||
c1.Close() |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c1, c2, nil |
||||
} |
||||
|
||||
func TestAuth(t *testing.T) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
|
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
|
||||
serverConf := ssh.ServerConfig{} |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
||||
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||
return nil, nil |
||||
} |
||||
|
||||
return nil, errors.New("pubkey rejected") |
||||
} |
||||
|
||||
go func() { |
||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||
if err != nil { |
||||
t.Fatalf("Server: %v", err) |
||||
} |
||||
conn.Close() |
||||
}() |
||||
|
||||
conf := ssh.ClientConfig{} |
||||
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) |
||||
conn, _, _, err := ssh.NewClientConn(b, "", &conf) |
||||
if err != nil { |
||||
t.Fatalf("NewClientConn: %v", err) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestLockClient(t *testing.T) { |
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
testLockAgent(agent, t) |
||||
} |
||||
|
||||
func testLockAgent(agent Agent, t *testing.T) { |
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 2 { |
||||
t.Errorf("Want 2 keys, got %v", keys) |
||||
} |
||||
|
||||
passphrase := []byte("secret") |
||||
if err := agent.Lock(passphrase); err != nil { |
||||
t.Errorf("Lock: %v", err) |
||||
} |
||||
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 0 { |
||||
t.Errorf("Want 0 keys, got %v", keys) |
||||
} |
||||
|
||||
signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) |
||||
if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { |
||||
t.Fatalf("Sign did not fail") |
||||
} |
||||
|
||||
if err := agent.Remove(signer.PublicKey()); err == nil { |
||||
t.Fatalf("Remove did not fail") |
||||
} |
||||
|
||||
if err := agent.RemoveAll(); err == nil { |
||||
t.Fatalf("RemoveAll did not fail") |
||||
} |
||||
|
||||
if err := agent.Unlock(nil); err == nil { |
||||
t.Errorf("Unlock with wrong passphrase succeeded") |
||||
} |
||||
if err := agent.Unlock(passphrase); err != nil { |
||||
t.Errorf("Unlock: %v", err) |
||||
} |
||||
|
||||
if err := agent.Remove(signer.PublicKey()); err != nil { |
||||
t.Fatalf("Remove: %v", err) |
||||
} |
||||
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 1 { |
||||
t.Errorf("Want 1 keys, got %v", keys) |
||||
} |
||||
} |
@ -1,103 +0,0 @@ |
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// RequestAgentForwarding sets up agent forwarding for the session.
|
||||
// ForwardToAgent or ForwardToRemote should be called to route
|
||||
// the authentication requests.
|
||||
func RequestAgentForwarding(session *ssh.Session) error { |
||||
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !ok { |
||||
return errors.New("forwarding request denied") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ForwardToAgent routes authentication requests to the given keyring.
|
||||
func ForwardToAgent(client *ssh.Client, keyring Agent) error { |
||||
channels := client.HandleChannelOpen(channelType) |
||||
if channels == nil { |
||||
return errors.New("agent: already have handler for " + channelType) |
||||
} |
||||
|
||||
go func() { |
||||
for ch := range channels { |
||||
channel, reqs, err := ch.Accept() |
||||
if err != nil { |
||||
continue |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
go func() { |
||||
ServeAgent(keyring, channel) |
||||
channel.Close() |
||||
}() |
||||
} |
||||
}() |
||||
return nil |
||||
} |
||||
|
||||
const channelType = "auth-agent@openssh.com" |
||||
|
||||
// ForwardToRemote routes authentication requests to the ssh-agent
|
||||
// process serving on the given unix socket.
|
||||
func ForwardToRemote(client *ssh.Client, addr string) error { |
||||
channels := client.HandleChannelOpen(channelType) |
||||
if channels == nil { |
||||
return errors.New("agent: already have handler for " + channelType) |
||||
} |
||||
conn, err := net.Dial("unix", addr) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
conn.Close() |
||||
|
||||
go func() { |
||||
for ch := range channels { |
||||
channel, reqs, err := ch.Accept() |
||||
if err != nil { |
||||
continue |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
go forwardUnixSocket(channel, addr) |
||||
} |
||||
}() |
||||
return nil |
||||
} |
||||
|
||||
func forwardUnixSocket(channel ssh.Channel, addr string) { |
||||
conn, err := net.Dial("unix", addr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
var wg sync.WaitGroup |
||||
wg.Add(2) |
||||
go func() { |
||||
io.Copy(conn, channel) |
||||
conn.(*net.UnixConn).CloseWrite() |
||||
wg.Done() |
||||
}() |
||||
go func() { |
||||
io.Copy(channel, conn) |
||||
channel.CloseWrite() |
||||
wg.Done() |
||||
}() |
||||
|
||||
wg.Wait() |
||||
conn.Close() |
||||
channel.Close() |
||||
} |
@ -1,184 +0,0 @@ |
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"crypto/subtle" |
||||
"errors" |
||||
"fmt" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
type privKey struct { |
||||
signer ssh.Signer |
||||
comment string |
||||
} |
||||
|
||||
type keyring struct { |
||||
mu sync.Mutex |
||||
keys []privKey |
||||
|
||||
locked bool |
||||
passphrase []byte |
||||
} |
||||
|
||||
var errLocked = errors.New("agent: locked") |
||||
|
||||
// NewKeyring returns an Agent that holds keys in memory. It is safe
|
||||
// for concurrent use by multiple goroutines.
|
||||
func NewKeyring() Agent { |
||||
return &keyring{} |
||||
} |
||||
|
||||
// RemoveAll removes all identities.
|
||||
func (r *keyring) RemoveAll() error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
r.keys = nil |
||||
return nil |
||||
} |
||||
|
||||
// Remove removes all identities with the given public key.
|
||||
func (r *keyring) Remove(key ssh.PublicKey) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
want := key.Marshal() |
||||
found := false |
||||
for i := 0; i < len(r.keys); { |
||||
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { |
||||
found = true |
||||
r.keys[i] = r.keys[len(r.keys)-1] |
||||
r.keys = r.keys[len(r.keys)-1:] |
||||
continue |
||||
} else { |
||||
i++ |
||||
} |
||||
} |
||||
|
||||
if !found { |
||||
return errors.New("agent: key not found") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||
func (r *keyring) Lock(passphrase []byte) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
r.locked = true |
||||
r.passphrase = passphrase |
||||
return nil |
||||
} |
||||
|
||||
// Unlock undoes the effect of Lock
|
||||
func (r *keyring) Unlock(passphrase []byte) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if !r.locked { |
||||
return errors.New("agent: not locked") |
||||
} |
||||
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { |
||||
return fmt.Errorf("agent: incorrect passphrase") |
||||
} |
||||
|
||||
r.locked = false |
||||
r.passphrase = nil |
||||
return nil |
||||
} |
||||
|
||||
// List returns the identities known to the agent.
|
||||
func (r *keyring) List() ([]*Key, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
// section 2.7: locked agents return empty.
|
||||
return nil, nil |
||||
} |
||||
|
||||
var ids []*Key |
||||
for _, k := range r.keys { |
||||
pub := k.signer.PublicKey() |
||||
ids = append(ids, &Key{ |
||||
Format: pub.Type(), |
||||
Blob: pub.Marshal(), |
||||
Comment: k.comment}) |
||||
} |
||||
return ids, nil |
||||
} |
||||
|
||||
// Insert adds a private key to the keyring. If a certificate
|
||||
// is given, that certificate is added as public key. Note that
|
||||
// any constraints given are ignored.
|
||||
func (r *keyring) Add(key AddedKey) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
signer, err := ssh.NewSignerFromKey(key.PrivateKey) |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if cert := key.Certificate; cert != nil { |
||||
signer, err = ssh.NewCertSigner(cert, signer) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
r.keys = append(r.keys, privKey{signer, key.Comment}) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Sign returns a signature for the data.
|
||||
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return nil, errLocked |
||||
} |
||||
|
||||
wanted := key.Marshal() |
||||
for _, k := range r.keys { |
||||
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { |
||||
return k.signer.Sign(rand.Reader, data) |
||||
} |
||||
} |
||||
return nil, errors.New("not found") |
||||
} |
||||
|
||||
// Signers returns signers for all the known keys.
|
||||
func (r *keyring) Signers() ([]ssh.Signer, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return nil, errLocked |
||||
} |
||||
|
||||
s := make([]ssh.Signer, 0, len(r.keys)) |
||||
for _, k := range r.keys { |
||||
s = append(s, k.signer) |
||||
} |
||||
return s, nil |
||||
} |
@ -1,209 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"crypto/rsa" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"math/big" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// Server wraps an Agent and uses it to implement the agent side of
|
||||
// the SSH-agent, wire protocol.
|
||||
type server struct { |
||||
agent Agent |
||||
} |
||||
|
||||
func (s *server) processRequestBytes(reqData []byte) []byte { |
||||
rep, err := s.processRequest(reqData) |
||||
if err != nil { |
||||
if err != errLocked { |
||||
// TODO(hanwen): provide better logging interface?
|
||||
log.Printf("agent %d: %v", reqData[0], err) |
||||
} |
||||
return []byte{agentFailure} |
||||
} |
||||
|
||||
if err == nil && rep == nil { |
||||
return []byte{agentSuccess} |
||||
} |
||||
|
||||
return ssh.Marshal(rep) |
||||
} |
||||
|
||||
func marshalKey(k *Key) []byte { |
||||
var record struct { |
||||
Blob []byte |
||||
Comment string |
||||
} |
||||
record.Blob = k.Marshal() |
||||
record.Comment = k.Comment |
||||
|
||||
return ssh.Marshal(&record) |
||||
} |
||||
|
||||
type agentV1IdentityMsg struct { |
||||
Numkeys uint32 `sshtype:"2"` |
||||
} |
||||
|
||||
type agentRemoveIdentityMsg struct { |
||||
KeyBlob []byte `sshtype:"18"` |
||||
} |
||||
|
||||
type agentLockMsg struct { |
||||
Passphrase []byte `sshtype:"22"` |
||||
} |
||||
|
||||
type agentUnlockMsg struct { |
||||
Passphrase []byte `sshtype:"23"` |
||||
} |
||||
|
||||
func (s *server) processRequest(data []byte) (interface{}, error) { |
||||
switch data[0] { |
||||
case agentRequestV1Identities: |
||||
return &agentV1IdentityMsg{0}, nil |
||||
case agentRemoveIdentity: |
||||
var req agentRemoveIdentityMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) |
||||
|
||||
case agentRemoveAllIdentities: |
||||
return nil, s.agent.RemoveAll() |
||||
|
||||
case agentLock: |
||||
var req agentLockMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return nil, s.agent.Lock(req.Passphrase) |
||||
|
||||
case agentUnlock: |
||||
var req agentLockMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, s.agent.Unlock(req.Passphrase) |
||||
|
||||
case agentSignRequest: |
||||
var req signRequestAgentMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
k := &Key{ |
||||
Format: wk.Format, |
||||
Blob: req.KeyBlob, |
||||
} |
||||
|
||||
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil |
||||
case agentRequestIdentities: |
||||
keys, err := s.agent.List() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
rep := identitiesAnswerAgentMsg{ |
||||
NumKeys: uint32(len(keys)), |
||||
} |
||||
for _, k := range keys { |
||||
rep.Keys = append(rep.Keys, marshalKey(k)...) |
||||
} |
||||
return rep, nil |
||||
case agentAddIdentity: |
||||
return nil, s.insertIdentity(data) |
||||
} |
||||
|
||||
return nil, fmt.Errorf("unknown opcode %d", data[0]) |
||||
} |
||||
|
||||
func (s *server) insertIdentity(req []byte) error { |
||||
var record struct { |
||||
Type string `sshtype:"17"` |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := ssh.Unmarshal(req, &record); err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch record.Type { |
||||
case ssh.KeyAlgoRSA: |
||||
var k rsaKeyMsg |
||||
if err := ssh.Unmarshal(req, &k); err != nil { |
||||
return err |
||||
} |
||||
|
||||
priv := rsa.PrivateKey{ |
||||
PublicKey: rsa.PublicKey{ |
||||
E: int(k.E.Int64()), |
||||
N: k.N, |
||||
}, |
||||
D: k.D, |
||||
Primes: []*big.Int{k.P, k.Q}, |
||||
} |
||||
priv.Precompute() |
||||
|
||||
return s.agent.Add(AddedKey{PrivateKey: &priv, Comment: k.Comments}) |
||||
} |
||||
return fmt.Errorf("not implemented: %s", record.Type) |
||||
} |
||||
|
||||
// ServeAgent serves the agent protocol on the given connection. It
|
||||
// returns when an I/O error occurs.
|
||||
func ServeAgent(agent Agent, c io.ReadWriter) error { |
||||
s := &server{agent} |
||||
|
||||
var length [4]byte |
||||
for { |
||||
if _, err := io.ReadFull(c, length[:]); err != nil { |
||||
return err |
||||
} |
||||
l := binary.BigEndian.Uint32(length[:]) |
||||
if l > maxAgentResponseBytes { |
||||
// We also cap requests.
|
||||
return fmt.Errorf("agent: request too large: %d", l) |
||||
} |
||||
|
||||
req := make([]byte, l) |
||||
if _, err := io.ReadFull(c, req); err != nil { |
||||
return err |
||||
} |
||||
|
||||
repData := s.processRequestBytes(req) |
||||
if len(repData) > maxAgentResponseBytes { |
||||
return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) |
||||
} |
||||
|
||||
binary.BigEndian.PutUint32(length[:], uint32(len(repData))) |
||||
if _, err := c.Write(length[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := c.Write(repData); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
@ -1,77 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
func TestServer(t *testing.T) { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
client := NewClient(c1) |
||||
|
||||
go ServeAgent(NewKeyring(), c2) |
||||
|
||||
testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) |
||||
} |
||||
|
||||
func TestLockServer(t *testing.T) { |
||||
testLockAgent(NewKeyring(), t) |
||||
} |
||||
|
||||
func TestSetupForwardAgent(t *testing.T) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
|
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
_, socket, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
serverConf := ssh.ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
incoming := make(chan *ssh.ServerConn, 1) |
||||
go func() { |
||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||
if err != nil { |
||||
t.Fatalf("Server: %v", err) |
||||
} |
||||
incoming <- conn |
||||
}() |
||||
|
||||
conf := ssh.ClientConfig{} |
||||
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) |
||||
if err != nil { |
||||
t.Fatalf("NewClientConn: %v", err) |
||||
} |
||||
client := ssh.NewClient(conn, chans, reqs) |
||||
|
||||
if err := ForwardToRemote(client, socket); err != nil { |
||||
t.Fatalf("SetupForwardAgent: %v", err) |
||||
} |
||||
|
||||
server := <-incoming |
||||
ch, reqs, err := server.OpenChannel(channelType, nil) |
||||
if err != nil { |
||||
t.Fatalf("OpenChannel(%q): %v", channelType, err) |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
agentClient := NewClient(ch) |
||||
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) |
||||
conn.Close() |
||||
} |
@ -1,64 +0,0 @@ |
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
||||
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
||||
// instances.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"fmt" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
"github.com/gogits/gogs/modules/crypto/ssh/testdata" |
||||
) |
||||
|
||||
var ( |
||||
testPrivateKeys map[string]interface{} |
||||
testSigners map[string]ssh.Signer |
||||
testPublicKeys map[string]ssh.PublicKey |
||||
) |
||||
|
||||
func init() { |
||||
var err error |
||||
|
||||
n := len(testdata.PEMBytes) |
||||
testPrivateKeys = make(map[string]interface{}, n) |
||||
testSigners = make(map[string]ssh.Signer, n) |
||||
testPublicKeys = make(map[string]ssh.PublicKey, n) |
||||
for t, k := range testdata.PEMBytes { |
||||
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) |
||||
} |
||||
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) |
||||
} |
||||
testPublicKeys[t] = testSigners[t].PublicKey() |
||||
} |
||||
|
||||
// Create a cert and sign it for use in tests.
|
||||
testCert := &ssh.Certificate{ |
||||
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
||||
ValidAfter: 0, // unix epoch
|
||||
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
||||
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||
Key: testPublicKeys["ecdsa"], |
||||
SignatureKey: testPublicKeys["rsa"], |
||||
Permissions: ssh.Permissions{ |
||||
CriticalOptions: map[string]string{}, |
||||
Extensions: map[string]string{}, |
||||
}, |
||||
} |
||||
testCert.SignCert(rand.Reader, testSigners["rsa"]) |
||||
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] |
||||
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) |
||||
} |
||||
} |
@ -1,122 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
"testing" |
||||
) |
||||
|
||||
type server struct { |
||||
*ServerConn |
||||
chans <-chan NewChannel |
||||
} |
||||
|
||||
func newServer(c net.Conn, conf *ServerConfig) (*server, error) { |
||||
sconn, chans, reqs, err := NewServerConn(c, conf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
go DiscardRequests(reqs) |
||||
return &server{sconn, chans}, nil |
||||
} |
||||
|
||||
func (s *server) Accept() (NewChannel, error) { |
||||
n, ok := <-s.chans |
||||
if !ok { |
||||
return nil, io.EOF |
||||
} |
||||
return n, nil |
||||
} |
||||
|
||||
func sshPipe() (Conn, *server, error) { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
clientConf := ClientConfig{ |
||||
User: "user", |
||||
} |
||||
serverConf := ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||
done := make(chan *server, 1) |
||||
go func() { |
||||
server, err := newServer(c2, &serverConf) |
||||
if err != nil { |
||||
done <- nil |
||||
} |
||||
done <- server |
||||
}() |
||||
|
||||
client, _, reqs, err := NewClientConn(c1, "", &clientConf) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
server := <-done |
||||
if server == nil { |
||||
return nil, nil, errors.New("server handshake failed.") |
||||
} |
||||
go DiscardRequests(reqs) |
||||
|
||||
return client, server, nil |
||||
} |
||||
|
||||
func BenchmarkEndToEnd(b *testing.B) { |
||||
b.StopTimer() |
||||
|
||||
client, server, err := sshPipe() |
||||
if err != nil { |
||||
b.Fatalf("sshPipe: %v", err) |
||||
} |
||||
|
||||
defer client.Close() |
||||
defer server.Close() |
||||
|
||||
size := (1 << 20) |
||||
input := make([]byte, size) |
||||
output := make([]byte, size) |
||||
b.SetBytes(int64(size)) |
||||
done := make(chan int, 1) |
||||
|
||||
go func() { |
||||
newCh, err := server.Accept() |
||||
if err != nil { |
||||
b.Fatalf("Client: %v", err) |
||||
} |
||||
ch, incoming, err := newCh.Accept() |
||||
go DiscardRequests(incoming) |
||||
for i := 0; i < b.N; i++ { |
||||
if _, err := io.ReadFull(ch, output); err != nil { |
||||
b.Fatalf("ReadFull: %v", err) |
||||
} |
||||
} |
||||
ch.Close() |
||||
done <- 1 |
||||
}() |
||||
|
||||
ch, in, err := client.OpenChannel("speed", nil) |
||||
if err != nil { |
||||
b.Fatalf("OpenChannel: %v", err) |
||||
} |
||||
go DiscardRequests(in) |
||||
|
||||
b.ResetTimer() |
||||
b.StartTimer() |
||||
for i := 0; i < b.N; i++ { |
||||
if _, err := ch.Write(input); err != nil { |
||||
b.Fatalf("WriteFull: %v", err) |
||||
} |
||||
} |
||||
ch.Close() |
||||
b.StopTimer() |
||||
|
||||
<-done |
||||
} |
@ -1,98 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
) |
||||
|
||||
// buffer provides a linked list buffer for data exchange
|
||||
// between producer and consumer. Theoretically the buffer is
|
||||
// of unlimited capacity as it does no allocation of its own.
|
||||
type buffer struct { |
||||
// protects concurrent access to head, tail and closed
|
||||
*sync.Cond |
||||
|
||||
head *element // the buffer that will be read first
|
||||
tail *element // the buffer that will be read last
|
||||
|
||||
closed bool |
||||
} |
||||
|
||||
// An element represents a single link in a linked list.
|
||||
type element struct { |
||||
buf []byte |
||||
next *element |
||||
} |
||||
|
||||
// newBuffer returns an empty buffer that is not closed.
|
||||
func newBuffer() *buffer { |
||||
e := new(element) |
||||
b := &buffer{ |
||||
Cond: newCond(), |
||||
head: e, |
||||
tail: e, |
||||
} |
||||
return b |
||||
} |
||||
|
||||
// write makes buf available for Read to receive.
|
||||
// buf must not be modified after the call to write.
|
||||
func (b *buffer) write(buf []byte) { |
||||
b.Cond.L.Lock() |
||||
e := &element{buf: buf} |
||||
b.tail.next = e |
||||
b.tail = e |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
} |
||||
|
||||
// eof closes the buffer. Reads from the buffer once all
|
||||
// the data has been consumed will receive os.EOF.
|
||||
func (b *buffer) eof() error { |
||||
b.Cond.L.Lock() |
||||
b.closed = true |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
// Read reads data from the internal buffer in buf. Reads will block
|
||||
// if no data is available, or until the buffer is closed.
|
||||
func (b *buffer) Read(buf []byte) (n int, err error) { |
||||
b.Cond.L.Lock() |
||||
defer b.Cond.L.Unlock() |
||||
|
||||
for len(buf) > 0 { |
||||
// if there is data in b.head, copy it
|
||||
if len(b.head.buf) > 0 { |
||||
r := copy(buf, b.head.buf) |
||||
buf, b.head.buf = buf[r:], b.head.buf[r:] |
||||
n += r |
||||
continue |
||||
} |
||||
// if there is a next buffer, make it the head
|
||||
if len(b.head.buf) == 0 && b.head != b.tail { |
||||
b.head = b.head.next |
||||
continue |
||||
} |
||||
|
||||
// if at least one byte has been copied, return
|
||||
if n > 0 { |
||||
break |
||||
} |
||||
|
||||
// if nothing was read, and there is nothing outstanding
|
||||
// check to see if the buffer is closed.
|
||||
if b.closed { |
||||
err = io.EOF |
||||
break |
||||
} |
||||
// out of buffers, wait for producer
|
||||
b.Cond.Wait() |
||||
} |
||||
return |
||||
} |
@ -1,87 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"testing" |
||||
) |
||||
|
||||
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") |
||||
|
||||
func TestBufferReadwrite(t *testing.T) { |
||||
b := newBuffer() |
||||
b.write(alphabet[:10]) |
||||
r, _ := b.Read(make([]byte, 10)) |
||||
if r != 10 { |
||||
t.Fatalf("Expected written == read == 10, written: 10, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:5]) |
||||
r, _ = b.Read(make([]byte, 10)) |
||||
if r != 5 { |
||||
t.Fatalf("Expected written == read == 5, written: 5, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
r, _ = b.Read(make([]byte, 5)) |
||||
if r != 5 { |
||||
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:5]) |
||||
b.write(alphabet[5:15]) |
||||
r, _ = b.Read(make([]byte, 10)) |
||||
r2, _ := b.Read(make([]byte, 10)) |
||||
if r != 10 || r2 != 5 || 15 != r+r2 { |
||||
t.Fatal("Expected written == read == 15") |
||||
} |
||||
} |
||||
|
||||
func TestBufferClose(t *testing.T) { |
||||
b := newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
_, err := b.Read(make([]byte, 5)) |
||||
if err != nil { |
||||
t.Fatal("expected read of 5 to not return EOF") |
||||
} |
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
r, err := b.Read(make([]byte, 5)) |
||||
r2, err2 := b.Read(make([]byte, 10)) |
||||
if r != 5 || r2 != 5 || err != nil || err2 != nil { |
||||
t.Fatal("expected reads of 5 and 5") |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
r, err = b.Read(make([]byte, 5)) |
||||
r2, err2 = b.Read(make([]byte, 10)) |
||||
r3, err3 := b.Read(make([]byte, 10)) |
||||
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { |
||||
t.Fatal("expected reads of 5 and 5 and 0, with EOF") |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(make([]byte, 5)) |
||||
b.write(make([]byte, 10)) |
||||
b.eof() |
||||
r, err = b.Read(make([]byte, 9)) |
||||
r2, err2 = b.Read(make([]byte, 3)) |
||||
r3, err3 = b.Read(make([]byte, 3)) |
||||
r4, err4 := b.Read(make([]byte, 10)) |
||||
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { |
||||
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) |
||||
} |
||||
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { |
||||
t.Fatal("Expected written == read == 15", r, r2, r3, r4) |
||||
} |
||||
} |
@ -1,501 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"sort" |
||||
"time" |
||||
) |
||||
|
||||
// These constants from [PROTOCOL.certkeys] represent the algorithm names
|
||||
// for certificate types supported by this package.
|
||||
const ( |
||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" |
||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" |
||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" |
||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" |
||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" |
||||
) |
||||
|
||||
// Certificate types distinguish between host and user
|
||||
// certificates. The values can be set in the CertType field of
|
||||
// Certificate.
|
||||
const ( |
||||
UserCert = 1 |
||||
HostCert = 2 |
||||
) |
||||
|
||||
// Signature represents a cryptographic signature.
|
||||
type Signature struct { |
||||
Format string |
||||
Blob []byte |
||||
} |
||||
|
||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
||||
// a certificate does not expire.
|
||||
const CertTimeInfinity = 1<<64 - 1 |
||||
|
||||
// An Certificate represents an OpenSSH certificate as defined in
|
||||
// [PROTOCOL.certkeys]?rev=1.8.
|
||||
type Certificate struct { |
||||
Nonce []byte |
||||
Key PublicKey |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []string |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
Permissions |
||||
Reserved []byte |
||||
SignatureKey PublicKey |
||||
Signature *Signature |
||||
} |
||||
|
||||
// genericCertData holds the key-independent part of the certificate data.
|
||||
// Overall, certificates contain an nonce, public key fields and
|
||||
// key-independent fields.
|
||||
type genericCertData struct { |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []byte |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
CriticalOptions []byte |
||||
Extensions []byte |
||||
Reserved []byte |
||||
SignatureKey []byte |
||||
Signature []byte |
||||
} |
||||
|
||||
func marshalStringList(namelist []string) []byte { |
||||
var to []byte |
||||
for _, name := range namelist { |
||||
s := struct{ N string }{name} |
||||
to = append(to, Marshal(&s)...) |
||||
} |
||||
return to |
||||
} |
||||
|
||||
type optionsTuple struct { |
||||
Key string |
||||
Value []byte |
||||
} |
||||
|
||||
type optionsTupleValue struct { |
||||
Value string |
||||
} |
||||
|
||||
// serialize a map of critical options or extensions
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty string value
|
||||
func marshalTuples(tups map[string]string) []byte { |
||||
keys := make([]string, 0, len(tups)) |
||||
for key := range tups { |
||||
keys = append(keys, key) |
||||
} |
||||
sort.Strings(keys) |
||||
|
||||
var ret []byte |
||||
for _, key := range keys { |
||||
s := optionsTuple{Key: key} |
||||
if value := tups[key]; len(value) > 0 { |
||||
s.Value = Marshal(&optionsTupleValue{value}) |
||||
} |
||||
ret = append(ret, Marshal(&s)...) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty option value
|
||||
func parseTuples(in []byte) (map[string]string, error) { |
||||
tups := map[string]string{} |
||||
var lastKey string |
||||
var haveLastKey bool |
||||
|
||||
for len(in) > 0 { |
||||
var key, val, extra []byte |
||||
var ok bool |
||||
|
||||
if key, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
keyStr := string(key) |
||||
// according to [PROTOCOL.certkeys], the names must be in
|
||||
// lexical order.
|
||||
if haveLastKey && keyStr <= lastKey { |
||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order") |
||||
} |
||||
lastKey, haveLastKey = keyStr, true |
||||
// the next field is a data field, which if non-empty has a string embedded
|
||||
if val, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(val) > 0 { |
||||
val, extra, ok = parseString(val) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(extra) > 0 { |
||||
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") |
||||
} |
||||
tups[keyStr] = string(val) |
||||
} else { |
||||
tups[keyStr] = "" |
||||
} |
||||
} |
||||
return tups, nil |
||||
} |
||||
|
||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) { |
||||
nonce, rest, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
|
||||
key, rest, err := parsePubKey(rest, privAlgo) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var g genericCertData |
||||
if err := Unmarshal(rest, &g); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c := &Certificate{ |
||||
Nonce: nonce, |
||||
Key: key, |
||||
Serial: g.Serial, |
||||
CertType: g.CertType, |
||||
KeyId: g.KeyId, |
||||
ValidAfter: g.ValidAfter, |
||||
ValidBefore: g.ValidBefore, |
||||
} |
||||
|
||||
for principals := g.ValidPrincipals; len(principals) > 0; { |
||||
principal, rest, ok := parseString(principals) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) |
||||
principals = rest |
||||
} |
||||
|
||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Extensions, err = parseTuples(g.Extensions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Reserved = g.Reserved |
||||
k, err := ParsePublicKey(g.SignatureKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.SignatureKey = k |
||||
c.Signature, rest, ok = parseSignatureBody(g.Signature) |
||||
if !ok || len(rest) > 0 { |
||||
return nil, errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
type openSSHCertSigner struct { |
||||
pub *Certificate |
||||
signer Signer |
||||
} |
||||
|
||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
||||
// private key is held by signer. It returns an error if the public key in cert
|
||||
// doesn't match the key used by signer.
|
||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { |
||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||
return nil, errors.New("ssh: signer and cert have different public key") |
||||
} |
||||
|
||||
return &openSSHCertSigner{cert, signer}, nil |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
return s.signer.Sign(rand, data) |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) PublicKey() PublicKey { |
||||
return s.pub |
||||
} |
||||
|
||||
const sourceAddressCriticalOption = "source-address" |
||||
|
||||
// CertChecker does the work of verifying a certificate. Its methods
|
||||
// can be plugged into ClientConfig.HostKeyCallback and
|
||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
||||
// minimally, the IsAuthority callback should be set.
|
||||
type CertChecker struct { |
||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
||||
// server application layer understands. These are only used
|
||||
// for user certificates.
|
||||
SupportedCriticalOptions []string |
||||
|
||||
// IsAuthority should return true if the key is recognized as
|
||||
// an authority. This allows for certificates to be signed by other
|
||||
// certificates.
|
||||
IsAuthority func(auth PublicKey) bool |
||||
|
||||
// Clock is used for verifying time stamps. If nil, time.Now
|
||||
// is used.
|
||||
Clock func() time.Time |
||||
|
||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
||||
// public key that is not a certificate. It must implement validation
|
||||
// of user keys or else, if nil, all such keys are rejected.
|
||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||
|
||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
||||
// public key that is not a certificate. It must implement host key
|
||||
// validation or else, if nil, all such keys are rejected.
|
||||
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error |
||||
|
||||
// IsRevoked is called for each certificate so that revocation checking
|
||||
// can be implemented. It should return true if the given certificate
|
||||
// is revoked and false otherwise. If nil, no certificates are
|
||||
// considered to have been revoked.
|
||||
IsRevoked func(cert *Certificate) bool |
||||
} |
||||
|
||||
// CheckHostKey checks a host key certificate. This method can be
|
||||
// plugged into ClientConfig.HostKeyCallback.
|
||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { |
||||
cert, ok := key.(*Certificate) |
||||
if !ok { |
||||
if c.HostKeyFallback != nil { |
||||
return c.HostKeyFallback(addr, remote, key) |
||||
} |
||||
return errors.New("ssh: non-certificate host key") |
||||
} |
||||
if cert.CertType != HostCert { |
||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) |
||||
} |
||||
|
||||
return c.CheckCert(addr, cert) |
||||
} |
||||
|
||||
// Authenticate checks a user certificate. Authenticate can be used as
|
||||
// a value for ServerConfig.PublicKeyCallback.
|
||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { |
||||
cert, ok := pubKey.(*Certificate) |
||||
if !ok { |
||||
if c.UserKeyFallback != nil { |
||||
return c.UserKeyFallback(conn, pubKey) |
||||
} |
||||
return nil, errors.New("ssh: normal key pairs not accepted") |
||||
} |
||||
|
||||
if cert.CertType != UserCert { |
||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) |
||||
} |
||||
|
||||
if err := c.CheckCert(conn.User(), cert); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &cert.Permissions, nil |
||||
} |
||||
|
||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
||||
// the signature of the certificate.
|
||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { |
||||
if c.IsRevoked != nil && c.IsRevoked(cert) { |
||||
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) |
||||
} |
||||
|
||||
for opt, _ := range cert.CriticalOptions { |
||||
// sourceAddressCriticalOption will be enforced by
|
||||
// serverAuthenticate
|
||||
if opt == sourceAddressCriticalOption { |
||||
continue |
||||
} |
||||
|
||||
found := false |
||||
for _, supp := range c.SupportedCriticalOptions { |
||||
if supp == opt { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) |
||||
} |
||||
} |
||||
|
||||
if len(cert.ValidPrincipals) > 0 { |
||||
// By default, certs are valid for all users/hosts.
|
||||
found := false |
||||
for _, p := range cert.ValidPrincipals { |
||||
if p == principal { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) |
||||
} |
||||
} |
||||
|
||||
if !c.IsAuthority(cert.SignatureKey) { |
||||
return fmt.Errorf("ssh: certificate signed by unrecognized authority") |
||||
} |
||||
|
||||
clock := c.Clock |
||||
if clock == nil { |
||||
clock = time.Now |
||||
} |
||||
|
||||
unixNow := clock().Unix() |
||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { |
||||
return fmt.Errorf("ssh: cert is not yet valid") |
||||
} |
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { |
||||
return fmt.Errorf("ssh: cert has expired") |
||||
} |
||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { |
||||
return fmt.Errorf("ssh: certificate signature does not verify") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// SignCert sets c.SignatureKey to the authority's public key and stores a
|
||||
// Signature, by authority, in the certificate.
|
||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { |
||||
c.Nonce = make([]byte, 32) |
||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil { |
||||
return err |
||||
} |
||||
c.SignatureKey = authority.PublicKey() |
||||
|
||||
sig, err := authority.Sign(rand, c.bytesForSigning()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
c.Signature = sig |
||||
return nil |
||||
} |
||||
|
||||
var certAlgoNames = map[string]string{ |
||||
KeyAlgoRSA: CertAlgoRSAv01, |
||||
KeyAlgoDSA: CertAlgoDSAv01, |
||||
KeyAlgoECDSA256: CertAlgoECDSA256v01, |
||||
KeyAlgoECDSA384: CertAlgoECDSA384v01, |
||||
KeyAlgoECDSA521: CertAlgoECDSA521v01, |
||||
} |
||||
|
||||
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
|
||||
// Panics if a non-certificate algorithm is passed.
|
||||
func certToPrivAlgo(algo string) string { |
||||
for privAlgo, pubAlgo := range certAlgoNames { |
||||
if pubAlgo == algo { |
||||
return privAlgo |
||||
} |
||||
} |
||||
panic("unknown cert algorithm") |
||||
} |
||||
|
||||
func (cert *Certificate) bytesForSigning() []byte { |
||||
c2 := *cert |
||||
c2.Signature = nil |
||||
out := c2.Marshal() |
||||
// Drop trailing signature length.
|
||||
return out[:len(out)-4] |
||||
} |
||||
|
||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
||||
// PublicKey interface.
|
||||
func (c *Certificate) Marshal() []byte { |
||||
generic := genericCertData{ |
||||
Serial: c.Serial, |
||||
CertType: c.CertType, |
||||
KeyId: c.KeyId, |
||||
ValidPrincipals: marshalStringList(c.ValidPrincipals), |
||||
ValidAfter: uint64(c.ValidAfter), |
||||
ValidBefore: uint64(c.ValidBefore), |
||||
CriticalOptions: marshalTuples(c.CriticalOptions), |
||||
Extensions: marshalTuples(c.Extensions), |
||||
Reserved: c.Reserved, |
||||
SignatureKey: c.SignatureKey.Marshal(), |
||||
} |
||||
if c.Signature != nil { |
||||
generic.Signature = Marshal(c.Signature) |
||||
} |
||||
genericBytes := Marshal(&generic) |
||||
keyBytes := c.Key.Marshal() |
||||
_, keyBytes, _ = parseString(keyBytes) |
||||
prefix := Marshal(&struct { |
||||
Name string |
||||
Nonce []byte |
||||
Key []byte `ssh:"rest"` |
||||
}{c.Type(), c.Nonce, keyBytes}) |
||||
|
||||
result := make([]byte, 0, len(prefix)+len(genericBytes)) |
||||
result = append(result, prefix...) |
||||
result = append(result, genericBytes...) |
||||
return result |
||||
} |
||||
|
||||
// Type returns the key name. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Type() string { |
||||
algo, ok := certAlgoNames[c.Key.Type()] |
||||
if !ok { |
||||
panic("unknown cert key type") |
||||
} |
||||
return algo |
||||
} |
||||
|
||||
// Verify verifies a signature against the certificate's public
|
||||
// key. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Verify(data []byte, sig *Signature) error { |
||||
return c.Key.Verify(data, sig) |
||||
} |
||||
|
||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
format, in, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out = &Signature{ |
||||
Format: string(format), |
||||
} |
||||
|
||||
if out.Blob, in, ok = parseString(in); !ok { |
||||
return |
||||
} |
||||
|
||||
return out, in, ok |
||||
} |
||||
|
||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
sigBytes, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out, trailing, ok := parseSignatureBody(sigBytes) |
||||
if !ok || len(trailing) > 0 { |
||||
return nil, nil, false |
||||
} |
||||
return |
||||
} |
@ -1,216 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
// Cert generated by ssh-keygen 6.0p1 Debian-4.
|
||||
// % ssh-keygen -s ca-key -I test user-key
|
||||
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` |
||||
|
||||
func TestParseCert(t *testing.T) { |
||||
authKeyBytes := []byte(exampleSSHCert) |
||||
|
||||
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
if len(rest) > 0 { |
||||
t.Errorf("rest: got %q, want empty", rest) |
||||
} |
||||
|
||||
if _, ok := key.(*Certificate); !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
|
||||
marshaled := MarshalAuthorizedKey(key) |
||||
// Before comparison, remove the trailing newline that
|
||||
// MarshalAuthorizedKey adds.
|
||||
marshaled = marshaled[:len(marshaled)-1] |
||||
if !bytes.Equal(authKeyBytes, marshaled) { |
||||
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||
} |
||||
} |
||||
|
||||
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
|
||||
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
|
||||
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
|
||||
// Critical Options:
|
||||
// force-command /bin/sleep
|
||||
// source-address 192.168.1.0/24
|
||||
// Extensions:
|
||||
// permit-X11-forwarding
|
||||
// permit-agent-forwarding
|
||||
// permit-port-forwarding
|
||||
// permit-pty
|
||||
// permit-user-rc
|
||||
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` |
||||
|
||||
func TestParseCertWithOptions(t *testing.T) { |
||||
opts := map[string]string{ |
||||
"source-address": "192.168.1.0/24", |
||||
"force-command": "/bin/sleep", |
||||
} |
||||
exts := map[string]string{ |
||||
"permit-X11-forwarding": "", |
||||
"permit-agent-forwarding": "", |
||||
"permit-port-forwarding": "", |
||||
"permit-pty": "", |
||||
"permit-user-rc": "", |
||||
} |
||||
authKeyBytes := []byte(exampleSSHCertWithOptions) |
||||
|
||||
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
if len(rest) > 0 { |
||||
t.Errorf("rest: got %q, want empty", rest) |
||||
} |
||||
cert, ok := key.(*Certificate) |
||||
if !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
if !reflect.DeepEqual(cert.CriticalOptions, opts) { |
||||
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) |
||||
} |
||||
if !reflect.DeepEqual(cert.Extensions, exts) { |
||||
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) |
||||
} |
||||
marshaled := MarshalAuthorizedKey(key) |
||||
// Before comparison, remove the trailing newline that
|
||||
// MarshalAuthorizedKey adds.
|
||||
marshaled = marshaled[:len(marshaled)-1] |
||||
if !bytes.Equal(authKeyBytes, marshaled) { |
||||
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||
} |
||||
} |
||||
|
||||
func TestValidateCert(t *testing.T) { |
||||
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
validCert, ok := key.(*Certificate) |
||||
if !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
checker := CertChecker{} |
||||
checker.IsAuthority = func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) |
||||
} |
||||
|
||||
if err := checker.CheckCert("user", validCert); err != nil { |
||||
t.Errorf("Unable to validate certificate: %v", err) |
||||
} |
||||
invalidCert := &Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
SignatureKey: testPublicKeys["ecdsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
Signature: &Signature{}, |
||||
} |
||||
if err := checker.CheckCert("user", invalidCert); err == nil { |
||||
t.Error("Invalid cert signature passed validation") |
||||
} |
||||
} |
||||
|
||||
func TestValidateCertTime(t *testing.T) { |
||||
cert := Certificate{ |
||||
ValidPrincipals: []string{"user"}, |
||||
Key: testPublicKeys["rsa"], |
||||
ValidAfter: 50, |
||||
ValidBefore: 100, |
||||
} |
||||
|
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
for ts, ok := range map[int64]bool{ |
||||
25: false, |
||||
50: true, |
||||
99: true, |
||||
100: false, |
||||
125: false, |
||||
} { |
||||
checker := CertChecker{ |
||||
Clock: func() time.Time { return time.Unix(ts, 0) }, |
||||
} |
||||
checker.IsAuthority = func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), |
||||
testPublicKeys["ecdsa"].Marshal()) |
||||
} |
||||
|
||||
if v := checker.CheckCert("user", &cert); (v == nil) != ok { |
||||
t.Errorf("Authenticate(%d): %v", ts, v) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// TODO(hanwen): tests for
|
||||
//
|
||||
// host keys:
|
||||
// * fallbacks
|
||||
|
||||
func TestHostKeyCert(t *testing.T) { |
||||
cert := &Certificate{ |
||||
ValidPrincipals: []string{"hostname", "hostname.domain"}, |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
CertType: HostCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
checker := &CertChecker{ |
||||
IsAuthority: func(p PublicKey) bool { |
||||
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) |
||||
}, |
||||
} |
||||
|
||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||
if err != nil { |
||||
t.Errorf("NewCertSigner: %v", err) |
||||
} |
||||
|
||||
for _, name := range []string{"hostname", "otherhost"} { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
errc := make(chan error) |
||||
|
||||
go func() { |
||||
conf := ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
conf.AddHostKey(certSigner) |
||||
_, _, _, err := NewServerConn(c1, &conf) |
||||
errc <- err |
||||
}() |
||||
|
||||
config := &ClientConfig{ |
||||
User: "user", |
||||
HostKeyCallback: checker.CheckHostKey, |
||||
} |
||||
_, _, _, err = NewClientConn(c2, name, config) |
||||
|
||||
succeed := name == "hostname" |
||||
if (err == nil) != succeed { |
||||
t.Fatalf("NewClientConn(%q): %v", name, err) |
||||
} |
||||
|
||||
err = <-errc |
||||
if (err == nil) != succeed { |
||||
t.Fatalf("NewServerConn(%q): %v", name, err) |
||||
} |
||||
} |
||||
} |
@ -1,631 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"sync" |
||||
) |
||||
|
||||
const ( |
||||
minPacketLength = 9 |
||||
// channelMaxPacket contains the maximum number of bytes that will be
|
||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||
// the minimum.
|
||||
channelMaxPacket = 1 << 15 |
||||
// We follow OpenSSH here.
|
||||
channelWindowSize = 64 * channelMaxPacket |
||||
) |
||||
|
||||
// NewChannel represents an incoming request to a channel. It must either be
|
||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
||||
type NewChannel interface { |
||||
// Accept accepts the channel creation request. It returns the Channel
|
||||
// and a Go channel containing SSH requests. The Go channel must be
|
||||
// serviced otherwise the Channel will hang.
|
||||
Accept() (Channel, <-chan *Request, error) |
||||
|
||||
// Reject rejects the channel creation request. After calling
|
||||
// this, no other methods on the Channel may be called.
|
||||
Reject(reason RejectionReason, message string) error |
||||
|
||||
// ChannelType returns the type of the channel, as supplied by the
|
||||
// client.
|
||||
ChannelType() string |
||||
|
||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
||||
// by the client. This data is specific to the channel type.
|
||||
ExtraData() []byte |
||||
} |
||||
|
||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
||||
// that is multiplexed over an SSH connection.
|
||||
type Channel interface { |
||||
// Read reads up to len(data) bytes from the channel.
|
||||
Read(data []byte) (int, error) |
||||
|
||||
// Write writes len(data) bytes to the channel.
|
||||
Write(data []byte) (int, error) |
||||
|
||||
// Close signals end of channel use. No data may be sent after this
|
||||
// call.
|
||||
Close() error |
||||
|
||||
// CloseWrite signals the end of sending in-band
|
||||
// data. Requests may still be sent, and the other side may
|
||||
// still send data
|
||||
CloseWrite() error |
||||
|
||||
// SendRequest sends a channel request. If wantReply is true,
|
||||
// it will wait for a reply and return the result as a
|
||||
// boolean, otherwise the return value will be false. Channel
|
||||
// requests are out-of-band messages so they may be sent even
|
||||
// if the data stream is closed or blocked by flow control.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error) |
||||
|
||||
// Stderr returns an io.ReadWriter that writes to this channel
|
||||
// with the extended data type set to stderr. Stderr may
|
||||
// safely be read and written from a different goroutine than
|
||||
// Read and Write respectively.
|
||||
Stderr() io.ReadWriter |
||||
} |
||||
|
||||
// Request is a request sent outside of the normal stream of
|
||||
// data. Requests can either be specific to an SSH channel, or they
|
||||
// can be global.
|
||||
type Request struct { |
||||
Type string |
||||
WantReply bool |
||||
Payload []byte |
||||
|
||||
ch *channel |
||||
mux *mux |
||||
} |
||||
|
||||
// Reply sends a response to a request. It must be called for all requests
|
||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
||||
// ignored for replies to channel-specific requests.
|
||||
func (r *Request) Reply(ok bool, payload []byte) error { |
||||
if !r.WantReply { |
||||
return nil |
||||
} |
||||
|
||||
if r.ch == nil { |
||||
return r.mux.ackRequest(ok, payload) |
||||
} |
||||
|
||||
return r.ch.ackRequest(ok) |
||||
} |
||||
|
||||
// RejectionReason is an enumeration used when rejecting channel creation
|
||||
// requests. See RFC 4254, section 5.1.
|
||||
type RejectionReason uint32 |
||||
|
||||
const ( |
||||
Prohibited RejectionReason = iota + 1 |
||||
ConnectionFailed |
||||
UnknownChannelType |
||||
ResourceShortage |
||||
) |
||||
|
||||
// String converts the rejection reason to human readable form.
|
||||
func (r RejectionReason) String() string { |
||||
switch r { |
||||
case Prohibited: |
||||
return "administratively prohibited" |
||||
case ConnectionFailed: |
||||
return "connect failed" |
||||
case UnknownChannelType: |
||||
return "unknown channel type" |
||||
case ResourceShortage: |
||||
return "resource shortage" |
||||
} |
||||
return fmt.Sprintf("unknown reason %d", int(r)) |
||||
} |
||||
|
||||
func min(a uint32, b int) uint32 { |
||||
if a < uint32(b) { |
||||
return a |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
type channelDirection uint8 |
||||
|
||||
const ( |
||||
channelInbound channelDirection = iota |
||||
channelOutbound |
||||
) |
||||
|
||||
// channel is an implementation of the Channel interface that works
|
||||
// with the mux class.
|
||||
type channel struct { |
||||
// R/O after creation
|
||||
chanType string |
||||
extraData []byte |
||||
localId, remoteId uint32 |
||||
|
||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
||||
// payload sizes of normal and extended data packets for
|
||||
// receiving and sending, respectively. The wire packet will
|
||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
||||
maxIncomingPayload uint32 |
||||
maxRemotePayload uint32 |
||||
|
||||
mux *mux |
||||
|
||||
// decided is set to true if an accept or reject message has been sent
|
||||
// (for outbound channels) or received (for inbound channels).
|
||||
decided bool |
||||
|
||||
// direction contains either channelOutbound, for channels created
|
||||
// locally, or channelInbound, for channels created by the peer.
|
||||
direction channelDirection |
||||
|
||||
// Pending internal channel messages.
|
||||
msg chan interface{} |
||||
|
||||
// Since requests have no ID, there can be only one request
|
||||
// with WantReply=true outstanding. This lock is held by a
|
||||
// goroutine that has such an outgoing request pending.
|
||||
sentRequestMu sync.Mutex |
||||
|
||||
incomingRequests chan *Request |
||||
|
||||
sentEOF bool |
||||
|
||||
// thread-safe data
|
||||
remoteWin window |
||||
pending *buffer |
||||
extPending *buffer |
||||
|
||||
// windowMu protects myWindow, the flow-control window.
|
||||
windowMu sync.Mutex |
||||
myWindow uint32 |
||||
|
||||
// writeMu serializes calls to mux.conn.writePacket() and
|
||||
// protects sentClose and packetPool. This mutex must be
|
||||
// different from windowMu, as writePacket can block if there
|
||||
// is a key exchange pending.
|
||||
writeMu sync.Mutex |
||||
sentClose bool |
||||
|
||||
// packetPool has a buffer for each extended channel ID to
|
||||
// save allocations during writes.
|
||||
packetPool map[uint32][]byte |
||||
} |
||||
|
||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
||||
// sentClose. This method takes the lock c.writeMu.
|
||||
func (c *channel) writePacket(packet []byte) error { |
||||
c.writeMu.Lock() |
||||
if c.sentClose { |
||||
c.writeMu.Unlock() |
||||
return io.EOF |
||||
} |
||||
c.sentClose = (packet[0] == msgChannelClose) |
||||
err := c.mux.conn.writePacket(packet) |
||||
c.writeMu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *channel) sendMessage(msg interface{}) error { |
||||
if debugMux { |
||||
log.Printf("send %d: %#v", c.mux.chanList.offset, msg) |
||||
} |
||||
|
||||
p := Marshal(msg) |
||||
binary.BigEndian.PutUint32(p[1:], c.remoteId) |
||||
return c.writePacket(p) |
||||
} |
||||
|
||||
// WriteExtended writes data to a specific extended stream. These streams are
|
||||
// used, for example, for stderr.
|
||||
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { |
||||
if c.sentEOF { |
||||
return 0, io.EOF |
||||
} |
||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
||||
opCode := byte(msgChannelData) |
||||
headerLength := uint32(9) |
||||
if extendedCode > 0 { |
||||
headerLength += 4 |
||||
opCode = msgChannelExtendedData |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
packet := c.packetPool[extendedCode] |
||||
// We don't remove the buffer from packetPool, so
|
||||
// WriteExtended calls from different goroutines will be
|
||||
// flagged as errors by the race detector.
|
||||
c.writeMu.Unlock() |
||||
|
||||
for len(data) > 0 { |
||||
space := min(c.maxRemotePayload, len(data)) |
||||
if space, err = c.remoteWin.reserve(space); err != nil { |
||||
return n, err |
||||
} |
||||
if want := headerLength + space; uint32(cap(packet)) < want { |
||||
packet = make([]byte, want) |
||||
} else { |
||||
packet = packet[:want] |
||||
} |
||||
|
||||
todo := data[:space] |
||||
|
||||
packet[0] = opCode |
||||
binary.BigEndian.PutUint32(packet[1:], c.remoteId) |
||||
if extendedCode > 0 { |
||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) |
||||
} |
||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) |
||||
copy(packet[headerLength:], todo) |
||||
if err = c.writePacket(packet); err != nil { |
||||
return n, err |
||||
} |
||||
|
||||
n += len(todo) |
||||
data = data[len(todo):] |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
c.packetPool[extendedCode] = packet |
||||
c.writeMu.Unlock() |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) handleData(packet []byte) error { |
||||
headerLen := 9 |
||||
isExtendedData := packet[0] == msgChannelExtendedData |
||||
if isExtendedData { |
||||
headerLen = 13 |
||||
} |
||||
if len(packet) < headerLen { |
||||
// malformed data packet
|
||||
return parseError(packet[0]) |
||||
} |
||||
|
||||
var extended uint32 |
||||
if isExtendedData { |
||||
extended = binary.BigEndian.Uint32(packet[5:]) |
||||
} |
||||
|
||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) |
||||
if length == 0 { |
||||
return nil |
||||
} |
||||
if length > c.maxIncomingPayload { |
||||
// TODO(hanwen): should send Disconnect?
|
||||
return errors.New("ssh: incoming packet exceeds maximum payload size") |
||||
} |
||||
|
||||
data := packet[headerLen:] |
||||
if length != uint32(len(data)) { |
||||
return errors.New("ssh: wrong packet length") |
||||
} |
||||
|
||||
c.windowMu.Lock() |
||||
if c.myWindow < length { |
||||
c.windowMu.Unlock() |
||||
// TODO(hanwen): should send Disconnect with reason?
|
||||
return errors.New("ssh: remote side wrote too much") |
||||
} |
||||
c.myWindow -= length |
||||
c.windowMu.Unlock() |
||||
|
||||
if extended == 1 { |
||||
c.extPending.write(data) |
||||
} else if extended > 0 { |
||||
// discard other extended data.
|
||||
} else { |
||||
c.pending.write(data) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) adjustWindow(n uint32) error { |
||||
c.windowMu.Lock() |
||||
// Since myWindow is managed on our side, and can never exceed
|
||||
// the initial window setting, we don't worry about overflow.
|
||||
c.myWindow += uint32(n) |
||||
c.windowMu.Unlock() |
||||
return c.sendMessage(windowAdjustMsg{ |
||||
AdditionalBytes: uint32(n), |
||||
}) |
||||
} |
||||
|
||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { |
||||
switch extended { |
||||
case 1: |
||||
n, err = c.extPending.Read(data) |
||||
case 0: |
||||
n, err = c.pending.Read(data) |
||||
default: |
||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) |
||||
} |
||||
|
||||
if n > 0 { |
||||
err = c.adjustWindow(uint32(n)) |
||||
// sendWindowAdjust can return io.EOF if the remote
|
||||
// peer has closed the connection, however we want to
|
||||
// defer forwarding io.EOF to the caller of Read until
|
||||
// the buffer has been drained.
|
||||
if n > 0 && err == io.EOF { |
||||
err = nil |
||||
} |
||||
} |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) close() { |
||||
c.pending.eof() |
||||
c.extPending.eof() |
||||
close(c.msg) |
||||
close(c.incomingRequests) |
||||
c.writeMu.Lock() |
||||
// This is not necesary for a normal channel teardown, but if
|
||||
// there was another error, it is.
|
||||
c.sentClose = true |
||||
c.writeMu.Unlock() |
||||
// Unblock writers.
|
||||
c.remoteWin.close() |
||||
} |
||||
|
||||
// responseMessageReceived is called when a success or failure message is
|
||||
// received on a channel to check that such a message is reasonable for the
|
||||
// given channel.
|
||||
func (c *channel) responseMessageReceived() error { |
||||
if c.direction == channelInbound { |
||||
return errors.New("ssh: channel response message received on inbound channel") |
||||
} |
||||
if c.decided { |
||||
return errors.New("ssh: duplicate response received for channel") |
||||
} |
||||
c.decided = true |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) handlePacket(packet []byte) error { |
||||
switch packet[0] { |
||||
case msgChannelData, msgChannelExtendedData: |
||||
return c.handleData(packet) |
||||
case msgChannelClose: |
||||
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) |
||||
c.mux.chanList.remove(c.localId) |
||||
c.close() |
||||
return nil |
||||
case msgChannelEOF: |
||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
||||
// it is logical to signal EOF at the same time.
|
||||
c.extPending.eof() |
||||
c.pending.eof() |
||||
return nil |
||||
} |
||||
|
||||
decoded, err := decode(packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch msg := decoded.(type) { |
||||
case *channelOpenFailureMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
c.mux.chanList.remove(msg.PeersId) |
||||
c.msg <- msg |
||||
case *channelOpenConfirmMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) |
||||
} |
||||
c.remoteId = msg.MyId |
||||
c.maxRemotePayload = msg.MaxPacketSize |
||||
c.remoteWin.add(msg.MyWindow) |
||||
c.msg <- msg |
||||
case *windowAdjustMsg: |
||||
if !c.remoteWin.add(msg.AdditionalBytes) { |
||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) |
||||
} |
||||
case *channelRequestMsg: |
||||
req := Request{ |
||||
Type: msg.Request, |
||||
WantReply: msg.WantReply, |
||||
Payload: msg.RequestSpecificData, |
||||
ch: c, |
||||
} |
||||
|
||||
c.incomingRequests <- &req |
||||
default: |
||||
c.msg <- msg |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { |
||||
ch := &channel{ |
||||
remoteWin: window{Cond: newCond()}, |
||||
myWindow: channelWindowSize, |
||||
pending: newBuffer(), |
||||
extPending: newBuffer(), |
||||
direction: direction, |
||||
incomingRequests: make(chan *Request, 16), |
||||
msg: make(chan interface{}, 16), |
||||
chanType: chanType, |
||||
extraData: extraData, |
||||
mux: m, |
||||
packetPool: make(map[uint32][]byte), |
||||
} |
||||
ch.localId = m.chanList.add(ch) |
||||
return ch |
||||
} |
||||
|
||||
var errUndecided = errors.New("ssh: must Accept or Reject channel") |
||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") |
||||
|
||||
type extChannel struct { |
||||
code uint32 |
||||
ch *channel |
||||
} |
||||
|
||||
func (e *extChannel) Write(data []byte) (n int, err error) { |
||||
return e.ch.WriteExtended(data, e.code) |
||||
} |
||||
|
||||
func (e *extChannel) Read(data []byte) (n int, err error) { |
||||
return e.ch.ReadExtended(data, e.code) |
||||
} |
||||
|
||||
func (c *channel) Accept() (Channel, <-chan *Request, error) { |
||||
if c.decided { |
||||
return nil, nil, errDecidedAlready |
||||
} |
||||
c.maxIncomingPayload = channelMaxPacket |
||||
confirm := channelOpenConfirmMsg{ |
||||
PeersId: c.remoteId, |
||||
MyId: c.localId, |
||||
MyWindow: c.myWindow, |
||||
MaxPacketSize: c.maxIncomingPayload, |
||||
} |
||||
c.decided = true |
||||
if err := c.sendMessage(confirm); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c, c.incomingRequests, nil |
||||
} |
||||
|
||||
func (ch *channel) Reject(reason RejectionReason, message string) error { |
||||
if ch.decided { |
||||
return errDecidedAlready |
||||
} |
||||
reject := channelOpenFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
Reason: reason, |
||||
Message: message, |
||||
Language: "en", |
||||
} |
||||
ch.decided = true |
||||
return ch.sendMessage(reject) |
||||
} |
||||
|
||||
func (ch *channel) Read(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.ReadExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) Write(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.WriteExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) CloseWrite() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
ch.sentEOF = true |
||||
return ch.sendMessage(channelEOFMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
func (ch *channel) Close() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
return ch.sendMessage(channelCloseMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
||||
func (ch *channel) Extended(code uint32) io.ReadWriter { |
||||
if !ch.decided { |
||||
return nil |
||||
} |
||||
return &extChannel{code, ch} |
||||
} |
||||
|
||||
func (ch *channel) Stderr() io.ReadWriter { |
||||
return ch.Extended(1) |
||||
} |
||||
|
||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||
if !ch.decided { |
||||
return false, errUndecided |
||||
} |
||||
|
||||
if wantReply { |
||||
ch.sentRequestMu.Lock() |
||||
defer ch.sentRequestMu.Unlock() |
||||
} |
||||
|
||||
msg := channelRequestMsg{ |
||||
PeersId: ch.remoteId, |
||||
Request: name, |
||||
WantReply: wantReply, |
||||
RequestSpecificData: payload, |
||||
} |
||||
|
||||
if err := ch.sendMessage(msg); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
if wantReply { |
||||
m, ok := (<-ch.msg) |
||||
if !ok { |
||||
return false, io.EOF |
||||
} |
||||
switch m.(type) { |
||||
case *channelRequestFailureMsg: |
||||
return false, nil |
||||
case *channelRequestSuccessMsg: |
||||
return true, nil |
||||
default: |
||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) |
||||
} |
||||
} |
||||
|
||||
return false, nil |
||||
} |
||||
|
||||
// ackRequest either sends an ack or nack to the channel request.
|
||||
func (ch *channel) ackRequest(ok bool) error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
var msg interface{} |
||||
if !ok { |
||||
msg = channelRequestFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} else { |
||||
msg = channelRequestSuccessMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} |
||||
return ch.sendMessage(msg) |
||||
} |
||||
|
||||
func (ch *channel) ChannelType() string { |
||||
return ch.chanType |
||||
} |
||||
|
||||
func (ch *channel) ExtraData() []byte { |
||||
return ch.extraData |
||||
} |
@ -1,549 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/rc4" |
||||
"crypto/subtle" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"hash" |
||||
"io" |
||||
"io/ioutil" |
||||
) |
||||
|
||||
const ( |
||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||
|
||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
||||
// waffles on about reasonable limits.
|
||||
//
|
||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
||||
// the same. maxPacket is also used to ensure that uint32
|
||||
// length fields do not overflow, so it should remain well
|
||||
// below 4G.
|
||||
maxPacket = 256 * 1024 |
||||
) |
||||
|
||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||
// by the transport before the first key-exchange.
|
||||
type noneCipher struct{} |
||||
|
||||
func (c noneCipher) XORKeyStream(dst, src []byte) { |
||||
copy(dst, src) |
||||
} |
||||
|
||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return cipher.NewCTR(c, iv), nil |
||||
} |
||||
|
||||
func newRC4(key, iv []byte) (cipher.Stream, error) { |
||||
return rc4.NewCipher(key) |
||||
} |
||||
|
||||
type streamCipherMode struct { |
||||
keySize int |
||||
ivSize int |
||||
skip int |
||||
createFunc func(key, iv []byte) (cipher.Stream, error) |
||||
} |
||||
|
||||
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { |
||||
if len(key) < c.keySize { |
||||
panic("ssh: key length too small for cipher") |
||||
} |
||||
if len(iv) < c.ivSize { |
||||
panic("ssh: iv too small for cipher") |
||||
} |
||||
|
||||
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var streamDump []byte |
||||
if c.skip > 0 { |
||||
streamDump = make([]byte, 512) |
||||
} |
||||
|
||||
for remainingToDump := c.skip; remainingToDump > 0; { |
||||
dumpThisTime := remainingToDump |
||||
if dumpThisTime > len(streamDump) { |
||||
dumpThisTime = len(streamDump) |
||||
} |
||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) |
||||
remainingToDump -= dumpThisTime |
||||
} |
||||
|
||||
return stream, nil |
||||
} |
||||
|
||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
||||
// are not supported and will not be negotiated, even if explicitly requested in
|
||||
// ClientConfig.Crypto.Ciphers.
|
||||
var cipherModes = map[string]*streamCipherMode{ |
||||
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
||||
// are defined in the order specified in the RFC.
|
||||
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, |
||||
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, |
||||
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, |
||||
|
||||
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
||||
// They are defined in the order specified in the RFC.
|
||||
"arcfour128": {16, 0, 1536, newRC4}, |
||||
"arcfour256": {32, 0, 1536, newRC4}, |
||||
|
||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
||||
// RC4) has problems with weak keys, and should be used with caution."
|
||||
// RFC4345 introduces improved versions of Arcfour.
|
||||
"arcfour": {16, 0, 0, newRC4}, |
||||
|
||||
// AES-GCM is not a stream cipher, so it is constructed with a
|
||||
// special case. If we add any more non-stream ciphers, we
|
||||
// should invest a cleaner way to do this.
|
||||
gcmCipherID: {16, 12, 0, nil}, |
||||
|
||||
// insecure cipher, see http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf
|
||||
// uncomment below to enable it.
|
||||
// aes128cbcID: {16, aes.BlockSize, 0, nil},
|
||||
} |
||||
|
||||
// prefixLen is the length of the packet prefix that contains the packet length
|
||||
// and number of padding bytes.
|
||||
const prefixLen = 5 |
||||
|
||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
||||
type streamPacketCipher struct { |
||||
mac hash.Hash |
||||
cipher cipher.Stream |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
prefix [prefixLen]byte |
||||
seqNumBytes [4]byte |
||||
padding [2 * packetSizeMultiple]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
} |
||||
|
||||
// readPacket reads and decrypt a single packet from the reader argument.
|
||||
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
length := binary.BigEndian.Uint32(s.prefix[0:4]) |
||||
paddingLength := uint32(s.prefix[4]) |
||||
|
||||
var macSize uint32 |
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
s.mac.Write(s.prefix[:]) |
||||
macSize = uint32(s.mac.Size()) |
||||
} |
||||
|
||||
if length <= paddingLength+1 { |
||||
return nil, errors.New("ssh: invalid packet length, packet too small") |
||||
} |
||||
|
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: invalid packet length, packet too large") |
||||
} |
||||
|
||||
// the maxPacket check above ensures that length-1+macSize
|
||||
// does not overflow.
|
||||
if uint32(cap(s.packetData)) < length-1+macSize { |
||||
s.packetData = make([]byte, length-1+macSize) |
||||
} else { |
||||
s.packetData = s.packetData[:length-1+macSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, s.packetData); err != nil { |
||||
return nil, err |
||||
} |
||||
mac := s.packetData[length-1:] |
||||
data := s.packetData[:length-1] |
||||
s.cipher.XORKeyStream(data, data) |
||||
|
||||
if s.mac != nil { |
||||
s.mac.Write(data) |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { |
||||
return nil, errors.New("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return s.packetData[:length-paddingLength-1], nil |
||||
} |
||||
|
||||
// writePacket encrypts and sends a packet of data to the writer argument
|
||||
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
if len(packet) > maxPacket { |
||||
return errors.New("ssh: packet too large") |
||||
} |
||||
|
||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple |
||||
if paddingLength < 4 { |
||||
paddingLength += packetSizeMultiple |
||||
} |
||||
|
||||
length := len(packet) + 1 + paddingLength |
||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) |
||||
s.prefix[4] = byte(paddingLength) |
||||
padding := s.padding[:paddingLength] |
||||
if _, err := io.ReadFull(rand, padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
s.mac.Write(s.prefix[:]) |
||||
s.mac.Write(packet) |
||||
s.mac.Write(padding) |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
s.cipher.XORKeyStream(packet, packet) |
||||
s.cipher.XORKeyStream(padding, padding) |
||||
|
||||
if _, err := w.Write(s.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(packet); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if _, err := w.Write(s.macResult); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
type gcmCipher struct { |
||||
aead cipher.AEAD |
||||
prefix [4]byte |
||||
iv []byte |
||||
buf []byte |
||||
} |
||||
|
||||
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
aead, err := cipher.NewGCM(c) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &gcmCipher{ |
||||
aead: aead, |
||||
iv: iv, |
||||
}, nil |
||||
} |
||||
|
||||
const gcmTagSize = 16 |
||||
|
||||
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
// Pad out to multiple of 16 bytes. This is different from the
|
||||
// stream cipher because that encrypts the length too.
|
||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) |
||||
if padding < 4 { |
||||
padding += packetSizeMultiple |
||||
} |
||||
|
||||
length := uint32(len(packet) + int(padding) + 1) |
||||
binary.BigEndian.PutUint32(c.prefix[:], length) |
||||
if _, err := w.Write(c.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if cap(c.buf) < int(length) { |
||||
c.buf = make([]byte, length) |
||||
} else { |
||||
c.buf = c.buf[:length] |
||||
} |
||||
|
||||
c.buf[0] = padding |
||||
copy(c.buf[1:], packet) |
||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { |
||||
return err |
||||
} |
||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if _, err := w.Write(c.buf); err != nil { |
||||
return err |
||||
} |
||||
c.incIV() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *gcmCipher) incIV() { |
||||
for i := 4 + 7; i >= 4; i-- { |
||||
c.iv[i]++ |
||||
if c.iv[i] != 0 { |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
length := binary.BigEndian.Uint32(c.prefix[:]) |
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: max packet length exceeded.") |
||||
} |
||||
|
||||
if cap(c.buf) < int(length+gcmTagSize) { |
||||
c.buf = make([]byte, length+gcmTagSize) |
||||
} else { |
||||
c.buf = c.buf[:length+gcmTagSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, c.buf); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.incIV() |
||||
|
||||
padding := plain[0] |
||||
if padding < 4 || padding >= 20 { |
||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding) |
||||
} |
||||
|
||||
if int(padding+1) >= len(plain) { |
||||
return nil, fmt.Errorf("ssh: padding %d too large", padding) |
||||
} |
||||
plain = plain[1 : length-uint32(padding)] |
||||
return plain, nil |
||||
} |
||||
|
||||
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
||||
type cbcCipher struct { |
||||
mac hash.Hash |
||||
macSize uint32 |
||||
decrypter cipher.BlockMode |
||||
encrypter cipher.BlockMode |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
seqNumBytes [4]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
|
||||
// Amount of data we should still read to hide which
|
||||
// verification error triggered.
|
||||
oracleCamouflage uint32 |
||||
} |
||||
|
||||
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cbc := &cbcCipher{ |
||||
mac: macModes[algs.MAC].new(macKey), |
||||
decrypter: cipher.NewCBCDecrypter(c, iv), |
||||
encrypter: cipher.NewCBCEncrypter(c, iv), |
||||
packetData: make([]byte, 1024), |
||||
} |
||||
if cbc.mac != nil { |
||||
cbc.macSize = uint32(cbc.mac.Size()) |
||||
} |
||||
|
||||
return cbc, nil |
||||
} |
||||
|
||||
func maxUInt32(a, b int) uint32 { |
||||
if a > b { |
||||
return uint32(a) |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
const ( |
||||
cbcMinPacketSizeMultiple = 8 |
||||
cbcMinPacketSize = 16 |
||||
cbcMinPaddingSize = 4 |
||||
) |
||||
|
||||
// cbcError represents a verification error that may leak information.
|
||||
type cbcError string |
||||
|
||||
func (e cbcError) Error() string { return string(e) } |
||||
|
||||
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
p, err := c.readPacketLeaky(seqNum, r) |
||||
if err != nil { |
||||
if _, ok := err.(cbcError); ok { |
||||
// Verification error: read a fixed amount of
|
||||
// data, to make distinguishing between
|
||||
// failing MAC and failing length check more
|
||||
// difficult.
|
||||
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) |
||||
} |
||||
} |
||||
return p, err |
||||
} |
||||
|
||||
func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
blockSize := c.decrypter.BlockSize() |
||||
|
||||
// Read the header, which will include some of the subsequent data in the
|
||||
// case of block ciphers - this is copied back to the payload later.
|
||||
// How many bytes of payload/padding will be read with this first read.
|
||||
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) |
||||
firstBlock := c.packetData[:firstBlockLength] |
||||
if _, err := io.ReadFull(r, firstBlock); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength |
||||
|
||||
c.decrypter.CryptBlocks(firstBlock, firstBlock) |
||||
length := binary.BigEndian.Uint32(firstBlock[:4]) |
||||
if length > maxPacket { |
||||
return nil, cbcError("ssh: packet too large") |
||||
} |
||||
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { |
||||
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
||||
// is larger) bytes.
|
||||
return nil, cbcError("ssh: packet too small") |
||||
} |
||||
// The length of the packet (including the length field but not the MAC) must
|
||||
// be a multiple of the block size or 8, whichever is larger.
|
||||
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { |
||||
return nil, cbcError("ssh: invalid packet length multiple") |
||||
} |
||||
|
||||
paddingLength := uint32(firstBlock[4]) |
||||
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { |
||||
return nil, cbcError("ssh: invalid packet length") |
||||
} |
||||
|
||||
// Positions within the c.packetData buffer:
|
||||
macStart := 4 + length |
||||
paddingStart := macStart - paddingLength |
||||
|
||||
// Entire packet size, starting before length, ending at end of mac.
|
||||
entirePacketSize := macStart + c.macSize |
||||
|
||||
// Ensure c.packetData is large enough for the entire packet data.
|
||||
if uint32(cap(c.packetData)) < entirePacketSize { |
||||
// Still need to upsize and copy, but this should be rare at runtime, only
|
||||
// on upsizing the packetData buffer.
|
||||
c.packetData = make([]byte, entirePacketSize) |
||||
copy(c.packetData, firstBlock) |
||||
} else { |
||||
c.packetData = c.packetData[:entirePacketSize] |
||||
} |
||||
|
||||
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { |
||||
return nil, err |
||||
} else { |
||||
c.oracleCamouflage -= uint32(n) |
||||
} |
||||
|
||||
remainingCrypted := c.packetData[firstBlockLength:macStart] |
||||
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) |
||||
|
||||
mac := c.packetData[macStart:] |
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData[:macStart]) |
||||
c.macResult = c.mac.Sum(c.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { |
||||
return nil, cbcError("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return c.packetData[prefixLen:paddingStart], nil |
||||
} |
||||
|
||||
func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) |
||||
|
||||
// Length of encrypted portion of the packet (header, payload, padding).
|
||||
// Enforce minimum padding and packet size.
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) |
||||
// Enforce block size.
|
||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize |
||||
|
||||
length := encLength - 4 |
||||
paddingLength := int(length) - (1 + len(packet)) |
||||
|
||||
// Overall buffer contains: header, payload, padding, mac.
|
||||
// Space for the MAC is reserved in the capacity but not the slice length.
|
||||
bufferSize := encLength + c.macSize |
||||
if uint32(cap(c.packetData)) < bufferSize { |
||||
c.packetData = make([]byte, encLength, bufferSize) |
||||
} else { |
||||
c.packetData = c.packetData[:encLength] |
||||
} |
||||
|
||||
p := c.packetData |
||||
|
||||
// Packet header.
|
||||
binary.BigEndian.PutUint32(p, length) |
||||
p = p[4:] |
||||
p[0] = byte(paddingLength) |
||||
|
||||
// Payload.
|
||||
p = p[1:] |
||||
copy(p, packet) |
||||
|
||||
// Padding.
|
||||
p = p[len(packet):] |
||||
if _, err := io.ReadFull(rand, p); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData) |
||||
// The MAC is now appended into the capacity reserved for it earlier.
|
||||
c.packetData = c.mac.Sum(c.packetData) |
||||
} |
||||
|
||||
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) |
||||
|
||||
if _, err := w.Write(c.packetData); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -1,127 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto" |
||||
"crypto/aes" |
||||
"crypto/rand" |
||||
"testing" |
||||
) |
||||
|
||||
func TestDefaultCiphersExist(t *testing.T) { |
||||
for _, cipherAlgo := range supportedCiphers { |
||||
if _, ok := cipherModes[cipherAlgo]; !ok { |
||||
t.Errorf("default cipher %q is unknown", cipherAlgo) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPacketCiphers(t *testing.T) { |
||||
// Still test aes128cbc cipher althought it's commented out.
|
||||
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||
defer delete(cipherModes, aes128cbcID) |
||||
|
||||
for cipher := range cipherModes { |
||||
kr := &kexResult{Hash: crypto.SHA1} |
||||
algs := directionAlgorithms{ |
||||
Cipher: cipher, |
||||
MAC: "hmac-sha1", |
||||
Compression: "none", |
||||
} |
||||
client, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||
continue |
||||
} |
||||
server, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
want := "bla bla" |
||||
input := []byte(want) |
||||
buf := &bytes.Buffer{} |
||||
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||
t.Errorf("writePacket(%q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
packet, err := server.readPacket(0, buf) |
||||
if err != nil { |
||||
t.Errorf("readPacket(%q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
if string(packet) != want { |
||||
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestCBCOracleCounterMeasure(t *testing.T) { |
||||
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||
defer delete(cipherModes, aes128cbcID) |
||||
|
||||
kr := &kexResult{Hash: crypto.SHA1} |
||||
algs := directionAlgorithms{ |
||||
Cipher: aes128cbcID, |
||||
MAC: "hmac-sha1", |
||||
Compression: "none", |
||||
} |
||||
client, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Fatalf("newPacketCipher(client): %v", err) |
||||
} |
||||
|
||||
want := "bla bla" |
||||
input := []byte(want) |
||||
buf := &bytes.Buffer{} |
||||
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
packetSize := buf.Len() |
||||
buf.Write(make([]byte, 2*maxPacket)) |
||||
|
||||
// We corrupt each byte, but this usually will only test the
|
||||
// 'packet too large' or 'MAC failure' cases.
|
||||
lastRead := -1 |
||||
for i := 0; i < packetSize; i++ { |
||||
server, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Fatalf("newPacketCipher(client): %v", err) |
||||
} |
||||
|
||||
fresh := &bytes.Buffer{} |
||||
fresh.Write(buf.Bytes()) |
||||
fresh.Bytes()[i] ^= 0x01 |
||||
|
||||
before := fresh.Len() |
||||
_, err = server.readPacket(0, fresh) |
||||
if err == nil { |
||||
t.Errorf("corrupt byte %d: readPacket succeeded ", i) |
||||
continue |
||||
} |
||||
if _, ok := err.(cbcError); !ok { |
||||
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) |
||||
continue |
||||
} |
||||
|
||||
after := fresh.Len() |
||||
bytesRead := before - after |
||||
if bytesRead < maxPacket { |
||||
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) |
||||
continue |
||||
} |
||||
|
||||
if i > 0 && bytesRead != lastRead { |
||||
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) |
||||
} |
||||
lastRead = bytesRead |
||||
} |
||||
} |
@ -1,213 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"sync" |
||||
) |
||||
|
||||
// Client implements a traditional SSH client that supports shells,
|
||||
// subprocesses, port forwarding and tunneled dialing.
|
||||
type Client struct { |
||||
Conn |
||||
|
||||
forwards forwardList // forwarded tcpip connections from the remote side
|
||||
mu sync.Mutex |
||||
channelHandlers map[string]chan NewChannel |
||||
} |
||||
|
||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
||||
// for the given type are sent. If the type already is being handled,
|
||||
// nil is returned. The channel is closed when the connection is closed.
|
||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.channelHandlers == nil { |
||||
// The SSH channel has been closed.
|
||||
c := make(chan NewChannel) |
||||
close(c) |
||||
return c |
||||
} |
||||
|
||||
ch := c.channelHandlers[channelType] |
||||
if ch != nil { |
||||
return nil |
||||
} |
||||
|
||||
ch = make(chan NewChannel, 16) |
||||
c.channelHandlers[channelType] = ch |
||||
return ch |
||||
} |
||||
|
||||
// NewClient creates a Client on top of the given connection.
|
||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { |
||||
conn := &Client{ |
||||
Conn: c, |
||||
channelHandlers: make(map[string]chan NewChannel, 1), |
||||
} |
||||
|
||||
go conn.handleGlobalRequests(reqs) |
||||
go conn.handleChannelOpens(chans) |
||||
go func() { |
||||
conn.Wait() |
||||
conn.forwards.closeAll() |
||||
}() |
||||
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) |
||||
return conn |
||||
} |
||||
|
||||
// NewClientConn establishes an authenticated SSH connection using c
|
||||
// as the underlying transport. The Request and NewChannel channels
|
||||
// must be serviced or the connection will hang.
|
||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { |
||||
fullConf := *config |
||||
fullConf.SetDefaults() |
||||
conn := &connection{ |
||||
sshConn: sshConn{conn: c}, |
||||
} |
||||
|
||||
if err := conn.clientHandshake(addr, &fullConf); err != nil { |
||||
c.Close() |
||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) |
||||
} |
||||
conn.mux = newMux(conn.transport) |
||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil |
||||
} |
||||
|
||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
||||
// 7.
|
||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { |
||||
if config.ClientVersion != "" { |
||||
c.clientVersion = []byte(config.ClientVersion) |
||||
} else { |
||||
c.clientVersion = []byte(packageVersion) |
||||
} |
||||
var err error |
||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
c.transport = newClientTransport( |
||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */), |
||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) |
||||
if err := c.transport.requestKeyChange(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if packet, err := c.transport.readPacket(); err != nil { |
||||
return err |
||||
} else if packet[0] != msgNewKeys { |
||||
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
|
||||
// We just did the key change, so the session ID is established.
|
||||
c.sessionID = c.transport.getSessionID() |
||||
|
||||
return c.clientAuthenticate(config) |
||||
} |
||||
|
||||
// verifyHostKeySignature verifies the host key obtained in the key
|
||||
// exchange.
|
||||
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { |
||||
sig, rest, ok := parseSignatureBody(result.Signature) |
||||
if len(rest) > 0 || !ok { |
||||
return errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return hostKey.Verify(result.H, sig) |
||||
} |
||||
|
||||
// NewSession opens a new Session for this client. (A session is a remote
|
||||
// execution of a program.)
|
||||
func (c *Client) NewSession() (*Session, error) { |
||||
ch, in, err := c.OpenChannel("session", nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return newSession(ch, in) |
||||
} |
||||
|
||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) { |
||||
for r := range incoming { |
||||
// This handles keepalive messages and matches
|
||||
// the behaviour of OpenSSH.
|
||||
r.Reply(false, nil) |
||||
} |
||||
} |
||||
|
||||
// handleChannelOpens channel open messages from the remote side.
|
||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) { |
||||
for ch := range in { |
||||
c.mu.Lock() |
||||
handler := c.channelHandlers[ch.ChannelType()] |
||||
c.mu.Unlock() |
||||
|
||||
if handler != nil { |
||||
handler <- ch |
||||
} else { |
||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) |
||||
} |
||||
} |
||||
|
||||
c.mu.Lock() |
||||
for _, ch := range c.channelHandlers { |
||||
close(ch) |
||||
} |
||||
c.channelHandlers = nil |
||||
c.mu.Unlock() |
||||
} |
||||
|
||||
// Dial starts a client connection to the given SSH server. It is a
|
||||
// convenience function that connects to the given network address,
|
||||
// initiates the SSH handshake, and then sets up a Client. For access
|
||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
||||
// instead.
|
||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) { |
||||
conn, err := net.Dial(network, addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c, chans, reqs, err := NewClientConn(conn, addr, config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return NewClient(c, chans, reqs), nil |
||||
} |
||||
|
||||
// A ClientConfig structure is used to configure a Client. It must not be
|
||||
// modified after having been passed to an SSH function.
|
||||
type ClientConfig struct { |
||||
// Config contains configuration that is shared between clients and
|
||||
// servers.
|
||||
Config |
||||
|
||||
// User contains the username to authenticate as.
|
||||
User string |
||||
|
||||
// Auth contains possible authentication methods to use with the
|
||||
// server. Only the first instance of a particular RFC 4252 method will
|
||||
// be used during authentication.
|
||||
Auth []AuthMethod |
||||
|
||||
// HostKeyCallback, if not nil, is called during the cryptographic
|
||||
// handshake to validate the server's host key. A nil HostKeyCallback
|
||||
// implies that all host keys are accepted.
|
||||
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
|
||||
// ClientVersion contains the version identification string that will
|
||||
// be used for the connection. If empty, a reasonable default is used.
|
||||
ClientVersion string |
||||
|
||||
// HostKeyAlgorithms lists the key types that the client will
|
||||
// accept from the server as host key, in order of
|
||||
// preference. If empty, a reasonable default is used. Any
|
||||
// string returned from PublicKey.Type method may be used, or
|
||||
// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
|
||||
HostKeyAlgorithms []string |
||||
} |
@ -1,441 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
||||
func (c *connection) clientAuthenticate(config *ClientConfig) error { |
||||
// initiate user auth session
|
||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { |
||||
return err |
||||
} |
||||
packet, err := c.transport.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var serviceAccept serviceAcceptMsg |
||||
if err := Unmarshal(packet, &serviceAccept); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// during the authentication phase the client first attempts the "none" method
|
||||
// then any untried methods suggested by the server.
|
||||
tried := make(map[string]bool) |
||||
var lastMethods []string |
||||
for auth := AuthMethod(new(noneAuth)); auth != nil; { |
||||
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if ok { |
||||
// success
|
||||
return nil |
||||
} |
||||
tried[auth.method()] = true |
||||
if methods == nil { |
||||
methods = lastMethods |
||||
} |
||||
lastMethods = methods |
||||
|
||||
auth = nil |
||||
|
||||
findNext: |
||||
for _, a := range config.Auth { |
||||
candidateMethod := a.method() |
||||
if tried[candidateMethod] { |
||||
continue |
||||
} |
||||
for _, meth := range methods { |
||||
if meth == candidateMethod { |
||||
auth = a |
||||
break findNext |
||||
} |
||||
} |
||||
} |
||||
} |
||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) |
||||
} |
||||
|
||||
func keys(m map[string]bool) []string { |
||||
s := make([]string, 0, len(m)) |
||||
|
||||
for key := range m { |
||||
s = append(s, key) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
||||
type AuthMethod interface { |
||||
// auth authenticates user over transport t.
|
||||
// Returns true if authentication is successful.
|
||||
// If authentication is not successful, a []string of alternative
|
||||
// method names is returned. If the slice is nil, it will be ignored
|
||||
// and the previous set of possible methods will be reused.
|
||||
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) |
||||
|
||||
// method returns the RFC 4252 method name.
|
||||
method() string |
||||
} |
||||
|
||||
// "none" authentication, RFC 4252 section 5.2.
|
||||
type noneAuth int |
||||
|
||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "none", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (n *noneAuth) method() string { |
||||
return "none" |
||||
} |
||||
|
||||
// passwordCallback is an AuthMethod that fetches the password through
|
||||
// a function call, e.g. by prompting the user.
|
||||
type passwordCallback func() (password string, err error) |
||||
|
||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type passwordAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Reply bool |
||||
Password string |
||||
} |
||||
|
||||
pw, err := cb() |
||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
||||
// The program may only find out that the user doesn't have a password
|
||||
// when prompting.
|
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&passwordAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
Reply: false, |
||||
Password: pw, |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (cb passwordCallback) method() string { |
||||
return "password" |
||||
} |
||||
|
||||
// Password returns an AuthMethod using the given password.
|
||||
func Password(secret string) AuthMethod { |
||||
return passwordCallback(func() (string, error) { return secret, nil }) |
||||
} |
||||
|
||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
||||
// fetching a password.
|
||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { |
||||
return passwordCallback(prompt) |
||||
} |
||||
|
||||
type publickeyAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
||||
// should be used for authentication of the request.
|
||||
HasSig bool |
||||
Algoname string |
||||
PubKey []byte |
||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
||||
// validateKey
|
||||
Sig []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
||||
// pairs for authentication.
|
||||
type publicKeyCallback func() ([]Signer, error) |
||||
|
||||
func (cb publicKeyCallback) method() string { |
||||
return "publickey" |
||||
} |
||||
|
||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
// Authentication is performed in two stages. The first stage sends an
|
||||
// enquiry to test if each key is acceptable to the remote. The second
|
||||
// stage attempts to authenticate with the valid keys obtained in the
|
||||
// first stage.
|
||||
|
||||
signers, err := cb() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
var validKeys []Signer |
||||
for _, signer := range signers { |
||||
if ok, err := validateKey(signer.PublicKey(), user, c); ok { |
||||
validKeys = append(validKeys, signer) |
||||
} else { |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// methods that may continue if this auth is not successful.
|
||||
var methods []string |
||||
for _, signer := range validKeys { |
||||
pub := signer.PublicKey() |
||||
|
||||
pubKey := pub.Marshal() |
||||
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
}, []byte(pub.Type()), pubKey)) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// manually wrap the serialized signature in a string
|
||||
s := Marshal(sign) |
||||
sig := make([]byte, stringLength(len(s))) |
||||
marshalString(sig, s) |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
HasSig: true, |
||||
Algoname: pub.Type(), |
||||
PubKey: pubKey, |
||||
Sig: sig, |
||||
} |
||||
p := Marshal(&msg) |
||||
if err := c.writePacket(p); err != nil { |
||||
return false, nil, err |
||||
} |
||||
var success bool |
||||
success, methods, err = handleAuthResponse(c) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
if success { |
||||
return success, methods, err |
||||
} |
||||
} |
||||
return false, methods, nil |
||||
} |
||||
|
||||
// validateKey validates the key provided is acceptable to the server.
|
||||
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "publickey", |
||||
HasSig: false, |
||||
Algoname: key.Type(), |
||||
PubKey: pubKey, |
||||
} |
||||
if err := c.writePacket(Marshal(&msg)); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
return confirmKeyAck(key, c) |
||||
} |
||||
|
||||
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
algoname := key.Type() |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO(gpaul): add callback to present the banner to the user
|
||||
case msgUserAuthPubKeyOk: |
||||
var msg userAuthPubKeyOkMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, err |
||||
} |
||||
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { |
||||
return false, nil |
||||
} |
||||
return true, nil |
||||
case msgUserAuthFailure: |
||||
return false, nil |
||||
default: |
||||
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// PublicKeys returns an AuthMethod that uses the given key
|
||||
// pairs.
|
||||
func PublicKeys(signers ...Signer) AuthMethod { |
||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) |
||||
} |
||||
|
||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
||||
// function to obtain a list of key pairs.
|
||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { |
||||
return publicKeyCallback(getSigners) |
||||
} |
||||
|
||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
||||
// along with a list of remaining authentication methods to try next and
|
||||
// an error if an unexpected response was received.
|
||||
func handleAuthResponse(c packetConn) (bool, []string, error) { |
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: add callback to present the banner to the user
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
case msgDisconnect: |
||||
return false, nil, io.EOF |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// KeyboardInteractiveChallenge should print questions, optionally
|
||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
||||
// Challenge may be called multiple times in a single session. After
|
||||
// successful authentication, the server may send a challenge with no
|
||||
// questions, for which the user and instruction messages should be
|
||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
||||
// both CLI and GUI environments.
|
||||
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) |
||||
|
||||
// KeyboardInteractive returns a AuthMethod using a prompt/response
|
||||
// sequence controlled by the server.
|
||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { |
||||
return challenge |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) method() string { |
||||
return "keyboard-interactive" |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type initiateMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Language string |
||||
Submethods string |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&initiateMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "keyboard-interactive", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// like handleAuthResponse, but with less options.
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: Print banners during userauth.
|
||||
continue |
||||
case msgUserAuthInfoRequest: |
||||
// OK
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) |
||||
} |
||||
|
||||
var msg userAuthInfoRequestMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// Manually unpack the prompt/echo pairs.
|
||||
rest := msg.Prompts |
||||
var prompts []string |
||||
var echos []bool |
||||
for i := 0; i < int(msg.NumPrompts); i++ { |
||||
prompt, r, ok := parseString(rest) |
||||
if !ok || len(r) == 0 { |
||||
return false, nil, errors.New("ssh: prompt format error") |
||||
} |
||||
prompts = append(prompts, string(prompt)) |
||||
echos = append(echos, r[0] != 0) |
||||
rest = r[1:] |
||||
} |
||||
|
||||
if len(rest) != 0 { |
||||
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") |
||||
} |
||||
|
||||
answers, err := cb(msg.User, msg.Instruction, prompts, echos) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if len(answers) != len(prompts) { |
||||
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") |
||||
} |
||||
responseLength := 1 + 4 |
||||
for _, a := range answers { |
||||
responseLength += stringLength(len(a)) |
||||
} |
||||
serialized := make([]byte, responseLength) |
||||
p := serialized |
||||
p[0] = msgUserAuthInfoResponse |
||||
p = p[1:] |
||||
p = marshalUint32(p, uint32(len(answers))) |
||||
for _, a := range answers { |
||||
p = marshalString(p, []byte(a)) |
||||
} |
||||
|
||||
if err := c.writePacket(serialized); err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
@ -1,393 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
type keyboardInteractive map[string]string |
||||
|
||||
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { |
||||
var answers []string |
||||
for _, q := range questions { |
||||
answers = append(answers, cr[q]) |
||||
} |
||||
return answers, nil |
||||
} |
||||
|
||||
// reused internally by tests
|
||||
var clientPassword = "tiger" |
||||
|
||||
// tryAuth runs a handshake with a given config against an SSH server
|
||||
// with config serverConfig
|
||||
func tryAuth(t *testing.T, config *ClientConfig) error { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
certChecker := CertChecker{ |
||||
IsAuthority: func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) |
||||
}, |
||||
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||
return nil, nil |
||||
} |
||||
|
||||
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) |
||||
}, |
||||
IsRevoked: func(c *Certificate) bool { |
||||
return c.Serial == 666 |
||||
}, |
||||
} |
||||
|
||||
serverConfig := &ServerConfig{ |
||||
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { |
||||
if conn.User() == "testuser" && string(pass) == clientPassword { |
||||
return nil, nil |
||||
} |
||||
return nil, errors.New("password auth failed") |
||||
}, |
||||
PublicKeyCallback: certChecker.Authenticate, |
||||
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { |
||||
ans, err := challenge("user", |
||||
"instruction", |
||||
[]string{"question1", "question2"}, |
||||
[]bool{true, true}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" |
||||
if ok { |
||||
challenge("user", "motd", nil, nil) |
||||
return nil, nil |
||||
} |
||||
return nil, errors.New("keyboard-interactive failed") |
||||
}, |
||||
AuthLogCallback: func(conn ConnMetadata, method string, err error) { |
||||
t.Logf("user %q, method %q: %v", conn.User(), method, err) |
||||
}, |
||||
} |
||||
serverConfig.AddHostKey(testSigners["rsa"]) |
||||
|
||||
go newServer(c1, serverConfig) |
||||
_, _, _, err = NewClientConn(c2, "", config) |
||||
return err |
||||
} |
||||
|
||||
func TestClientAuthPublicKey(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodPassword(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
Password(clientPassword), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodFallback(t *testing.T) { |
||||
var passwordCalled bool |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
PasswordCallback( |
||||
func() (string, error) { |
||||
passwordCalled = true |
||||
return "WRONG", nil |
||||
}), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
|
||||
if passwordCalled { |
||||
t.Errorf("password auth tried before public-key auth.") |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodWrongPassword(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
Password("wrong"), |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodKeyboardInteractive(t *testing.T) { |
||||
answers := keyboardInteractive(map[string]string{ |
||||
"question1": "answer1", |
||||
"question2": "answer2", |
||||
}) |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
KeyboardInteractive(answers.Challenge), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { |
||||
answers := keyboardInteractive(map[string]string{ |
||||
"question1": "answer1", |
||||
"question2": "WRONG", |
||||
}) |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
KeyboardInteractive(answers.Challenge), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") |
||||
} |
||||
} |
||||
|
||||
// the mock server will only authenticate ssh-rsa keys
|
||||
func TestAuthMethodInvalidPublicKey(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["dsa"]), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Fatalf("dsa private key should not have authenticated with rsa public key") |
||||
} |
||||
} |
||||
|
||||
// the client should authenticate with the second key
|
||||
func TestAuthMethodRSAandDSA(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["dsa"], testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("client could not authenticate with rsa key: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestClientHMAC(t *testing.T) { |
||||
for _, mac := range supportedMACs { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
Config: Config{ |
||||
MACs: []string{mac}, |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// issue 4285.
|
||||
func TestClientUnsupportedCipher(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(), |
||||
}, |
||||
Config: Config{ |
||||
Ciphers: []string{"aes128-cbc"}, // not currently supported
|
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Errorf("expected no ciphers in common") |
||||
} |
||||
} |
||||
|
||||
func TestClientUnsupportedKex(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(), |
||||
}, |
||||
Config: Config{ |
||||
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
|
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { |
||||
t.Errorf("got %v, expected 'common algorithm'", err) |
||||
} |
||||
} |
||||
|
||||
func TestClientLoginCert(t *testing.T) { |
||||
cert := &Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
CertType: UserCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||
if err != nil { |
||||
t.Fatalf("NewCertSigner: %v", err) |
||||
} |
||||
|
||||
clientConfig := &ClientConfig{ |
||||
User: "user", |
||||
} |
||||
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) |
||||
|
||||
t.Log("should succeed") |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login failed: %v", err) |
||||
} |
||||
|
||||
t.Log("corrupted signature") |
||||
cert.Signature.Blob[0]++ |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with corrupted sig") |
||||
} |
||||
|
||||
t.Log("revoked") |
||||
cert.Serial = 666 |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("revoked cert login succeeded") |
||||
} |
||||
cert.Serial = 1 |
||||
|
||||
t.Log("sign with wrong key") |
||||
cert.SignCert(rand.Reader, testSigners["dsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with non-authoritive key") |
||||
} |
||||
|
||||
t.Log("host cert") |
||||
cert.CertType = HostCert |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with wrong type") |
||||
} |
||||
cert.CertType = UserCert |
||||
|
||||
t.Log("principal specified") |
||||
cert.ValidPrincipals = []string{"user"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login failed: %v", err) |
||||
} |
||||
|
||||
t.Log("wrong principal specified") |
||||
cert.ValidPrincipals = []string{"fred"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with wrong principal") |
||||
} |
||||
cert.ValidPrincipals = nil |
||||
|
||||
t.Log("added critical option") |
||||
cert.CriticalOptions = map[string]string{"root-access": "yes"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with unrecognized critical option") |
||||
} |
||||
|
||||
t.Log("allowed source address") |
||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login with source-address failed: %v", err) |
||||
} |
||||
|
||||
t.Log("disallowed source address") |
||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login with source-address succeeded") |
||||
} |
||||
} |
||||
|
||||
func testPermissionsPassing(withPermissions bool, t *testing.T) { |
||||
serverConfig := &ServerConfig{ |
||||
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||
if conn.User() == "nopermissions" { |
||||
return nil, nil |
||||
} else { |
||||
return &Permissions{}, nil |
||||
} |
||||
}, |
||||
} |
||||
serverConfig.AddHostKey(testSigners["rsa"]) |
||||
|
||||
clientConfig := &ClientConfig{ |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if withPermissions { |
||||
clientConfig.User = "permissions" |
||||
} else { |
||||
clientConfig.User = "nopermissions" |
||||
} |
||||
|
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
go NewClientConn(c2, "", clientConfig) |
||||
serverConn, err := newServer(c1, serverConfig) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if p := serverConn.Permissions; (p != nil) != withPermissions { |
||||
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) |
||||
} |
||||
} |
||||
|
||||
func TestPermissionsPassing(t *testing.T) { |
||||
testPermissionsPassing(true, t) |
||||
} |
||||
|
||||
func TestNoPermissionsPassing(t *testing.T) { |
||||
testPermissionsPassing(false, t) |
||||
} |
@ -1,39 +0,0 @@ |
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"net" |
||||
"testing" |
||||
) |
||||
|
||||
func testClientVersion(t *testing.T, config *ClientConfig, expected string) { |
||||
clientConn, serverConn := net.Pipe() |
||||
defer clientConn.Close() |
||||
receivedVersion := make(chan string, 1) |
||||
go func() { |
||||
version, err := readVersion(serverConn) |
||||
if err != nil { |
||||
receivedVersion <- "" |
||||
} else { |
||||
receivedVersion <- string(version) |
||||
} |
||||
serverConn.Close() |
||||
}() |
||||
NewClientConn(clientConn, "", config) |
||||
actual := <-receivedVersion |
||||
if actual != expected { |
||||
t.Fatalf("got %s; want %s", actual, expected) |
||||
} |
||||
} |
||||
|
||||
func TestCustomClientVersion(t *testing.T) { |
||||
version := "Test-Client-Version-0.0" |
||||
testClientVersion(t, &ClientConfig{ClientVersion: version}, version) |
||||
} |
||||
|
||||
func TestDefaultClientVersion(t *testing.T) { |
||||
testClientVersion(t, &ClientConfig{}, packageVersion) |
||||
} |
@ -1,354 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/rand" |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
|
||||
_ "crypto/sha1" |
||||
_ "crypto/sha256" |
||||
_ "crypto/sha512" |
||||
) |
||||
|
||||
// These are string constants in the SSH protocol.
|
||||
const ( |
||||
compressionNone = "none" |
||||
serviceUserAuth = "ssh-userauth" |
||||
serviceSSH = "ssh-connection" |
||||
) |
||||
|
||||
// supportedCiphers specifies the supported ciphers in preference order.
|
||||
var supportedCiphers = []string{ |
||||
"aes128-ctr", "aes192-ctr", "aes256-ctr", |
||||
"aes128-gcm@openssh.com", |
||||
"arcfour256", "arcfour128", |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
||||
// preference order.
|
||||
var supportedKexAlgos = []string{ |
||||
kexAlgoCurve25519SHA256, |
||||
// P384 and P521 are not constant-time yet, but since we don't
|
||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, |
||||
kexAlgoDH14SHA1, kexAlgoDH1SHA1, |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
|
||||
// of authenticating servers) in preference order.
|
||||
var supportedHostKeyAlgos = []string{ |
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, |
||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01, |
||||
|
||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||
KeyAlgoRSA, KeyAlgoDSA, |
||||
} |
||||
|
||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
||||
// because they have reached the end of their useful life.
|
||||
var supportedMACs = []string{ |
||||
"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", |
||||
} |
||||
|
||||
var supportedCompressions = []string{compressionNone} |
||||
|
||||
// hashFuncs keeps the mapping of supported algorithms to their respective
|
||||
// hashes needed for signature verification.
|
||||
var hashFuncs = map[string]crypto.Hash{ |
||||
KeyAlgoRSA: crypto.SHA1, |
||||
KeyAlgoDSA: crypto.SHA1, |
||||
KeyAlgoECDSA256: crypto.SHA256, |
||||
KeyAlgoECDSA384: crypto.SHA384, |
||||
KeyAlgoECDSA521: crypto.SHA512, |
||||
CertAlgoRSAv01: crypto.SHA1, |
||||
CertAlgoDSAv01: crypto.SHA1, |
||||
CertAlgoECDSA256v01: crypto.SHA256, |
||||
CertAlgoECDSA384v01: crypto.SHA384, |
||||
CertAlgoECDSA521v01: crypto.SHA512, |
||||
} |
||||
|
||||
// unexpectedMessageError results when the SSH message that we received didn't
|
||||
// match what we wanted.
|
||||
func unexpectedMessageError(expected, got uint8) error { |
||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) |
||||
} |
||||
|
||||
// parseError results from a malformed SSH message.
|
||||
func parseError(tag uint8) error { |
||||
return fmt.Errorf("ssh: parse error in message type %d", tag) |
||||
} |
||||
|
||||
func findCommon(what string, client []string, server []string) (common string, err error) { |
||||
for _, c := range client { |
||||
for _, s := range server { |
||||
if c == s { |
||||
return c, nil |
||||
} |
||||
} |
||||
} |
||||
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) |
||||
} |
||||
|
||||
type directionAlgorithms struct { |
||||
Cipher string |
||||
MAC string |
||||
Compression string |
||||
} |
||||
|
||||
type algorithms struct { |
||||
kex string |
||||
hostKey string |
||||
w directionAlgorithms |
||||
r directionAlgorithms |
||||
} |
||||
|
||||
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { |
||||
result := &algorithms{} |
||||
|
||||
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// If rekeythreshold is too small, we can't make any progress sending
|
||||
// stuff.
|
||||
const minRekeyThreshold uint64 = 256 |
||||
|
||||
// Config contains configuration data common to both ServerConfig and
|
||||
// ClientConfig.
|
||||
type Config struct { |
||||
// Rand provides the source of entropy for cryptographic
|
||||
// primitives. If Rand is nil, the cryptographic random reader
|
||||
// in package crypto/rand will be used.
|
||||
Rand io.Reader |
||||
|
||||
// The maximum number of bytes sent or received after which a
|
||||
// new key is negotiated. It must be at least 256. If
|
||||
// unspecified, 1 gigabyte is used.
|
||||
RekeyThreshold uint64 |
||||
|
||||
// The allowed key exchanges algorithms. If unspecified then a
|
||||
// default set of algorithms is used.
|
||||
KeyExchanges []string |
||||
|
||||
// The allowed cipher algorithms. If unspecified then a sensible
|
||||
// default is used.
|
||||
Ciphers []string |
||||
|
||||
// The allowed MAC algorithms. If unspecified then a sensible default
|
||||
// is used.
|
||||
MACs []string |
||||
} |
||||
|
||||
// SetDefaults sets sensible values for unset fields in config. This is
|
||||
// exported for testing: Configs passed to SSH functions are copied and have
|
||||
// default values set automatically.
|
||||
func (c *Config) SetDefaults() { |
||||
if c.Rand == nil { |
||||
c.Rand = rand.Reader |
||||
} |
||||
if c.Ciphers == nil { |
||||
c.Ciphers = supportedCiphers |
||||
} |
||||
var ciphers []string |
||||
for _, c := range c.Ciphers { |
||||
if cipherModes[c] != nil { |
||||
// reject the cipher if we have no cipherModes definition
|
||||
ciphers = append(ciphers, c) |
||||
} |
||||
} |
||||
c.Ciphers = ciphers |
||||
|
||||
if c.KeyExchanges == nil { |
||||
c.KeyExchanges = supportedKexAlgos |
||||
} |
||||
|
||||
if c.MACs == nil { |
||||
c.MACs = supportedMACs |
||||
} |
||||
|
||||
if c.RekeyThreshold == 0 { |
||||
// RFC 4253, section 9 suggests rekeying after 1G.
|
||||
c.RekeyThreshold = 1 << 30 |
||||
} |
||||
if c.RekeyThreshold < minRekeyThreshold { |
||||
c.RekeyThreshold = minRekeyThreshold |
||||
} |
||||
} |
||||
|
||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||
// possession of a private key. See RFC 4252, section 7.
|
||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { |
||||
data := struct { |
||||
Session []byte |
||||
Type byte |
||||
User string |
||||
Service string |
||||
Method string |
||||
Sign bool |
||||
Algo []byte |
||||
PubKey []byte |
||||
}{ |
||||
sessionId, |
||||
msgUserAuthRequest, |
||||
req.User, |
||||
req.Service, |
||||
req.Method, |
||||
true, |
||||
algo, |
||||
pubKey, |
||||
} |
||||
return Marshal(data) |
||||
} |
||||
|
||||
func appendU16(buf []byte, n uint16) []byte { |
||||
return append(buf, byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU32(buf []byte, n uint32) []byte { |
||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU64(buf []byte, n uint64) []byte { |
||||
return append(buf, |
||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), |
||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendInt(buf []byte, n int) []byte { |
||||
return appendU32(buf, uint32(n)) |
||||
} |
||||
|
||||
func appendString(buf []byte, s string) []byte { |
||||
buf = appendU32(buf, uint32(len(s))) |
||||
buf = append(buf, s...) |
||||
return buf |
||||
} |
||||
|
||||
func appendBool(buf []byte, b bool) []byte { |
||||
if b { |
||||
return append(buf, 1) |
||||
} |
||||
return append(buf, 0) |
||||
} |
||||
|
||||
// newCond is a helper to hide the fact that there is no usable zero
|
||||
// value for sync.Cond.
|
||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } |
||||
|
||||
// window represents the buffer available to clients
|
||||
// wishing to write to a channel.
|
||||
type window struct { |
||||
*sync.Cond |
||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||
writeWaiters int |
||||
closed bool |
||||
} |
||||
|
||||
// add adds win to the amount of window available
|
||||
// for consumers.
|
||||
func (w *window) add(win uint32) bool { |
||||
// a zero sized window adjust is a noop.
|
||||
if win == 0 { |
||||
return true |
||||
} |
||||
w.L.Lock() |
||||
if w.win+win < win { |
||||
w.L.Unlock() |
||||
return false |
||||
} |
||||
w.win += win |
||||
// It is unusual that multiple goroutines would be attempting to reserve
|
||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||
// that additional window is available.
|
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
return true |
||||
} |
||||
|
||||
// close sets the window to closed, so all reservations fail
|
||||
// immediately.
|
||||
func (w *window) close() { |
||||
w.L.Lock() |
||||
w.closed = true |
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
} |
||||
|
||||
// reserve reserves win from the available window capacity.
|
||||
// If no capacity remains, reserve will block. reserve may
|
||||
// return less than requested.
|
||||
func (w *window) reserve(win uint32) (uint32, error) { |
||||
var err error |
||||
w.L.Lock() |
||||
w.writeWaiters++ |
||||
w.Broadcast() |
||||
for w.win == 0 && !w.closed { |
||||
w.Wait() |
||||
} |
||||
w.writeWaiters-- |
||||
if w.win < win { |
||||
win = w.win |
||||
} |
||||
w.win -= win |
||||
if w.closed { |
||||
err = io.EOF |
||||
} |
||||
w.L.Unlock() |
||||
return win, err |
||||
} |
||||
|
||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
||||
// writes. It is used in tests only.
|
||||
func (w *window) waitWriterBlocked() { |
||||
w.Cond.L.Lock() |
||||
for w.writeWaiters == 0 { |
||||
w.Cond.Wait() |
||||
} |
||||
w.Cond.L.Unlock() |
||||
} |
@ -1,144 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
) |
||||
|
||||
// OpenChannelError is returned if the other side rejects an
|
||||
// OpenChannel request.
|
||||
type OpenChannelError struct { |
||||
Reason RejectionReason |
||||
Message string |
||||
} |
||||
|
||||
func (e *OpenChannelError) Error() string { |
||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
||||
} |
||||
|
||||
// ConnMetadata holds metadata for the connection.
|
||||
type ConnMetadata interface { |
||||
// User returns the user ID for this connection.
|
||||
// It is empty if no authentication is used.
|
||||
User() string |
||||
|
||||
// SessionID returns the sesson hash, also denoted by H.
|
||||
SessionID() []byte |
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
ClientVersion() []byte |
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
ServerVersion() []byte |
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
RemoteAddr() net.Addr |
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
LocalAddr() net.Addr |
||||
} |
||||
|
||||
// Conn represents an SSH connection for both server and client roles.
|
||||
// Conn is the basis for implementing an application layer, such
|
||||
// as ClientConn, which implements the traditional shell access for
|
||||
// clients.
|
||||
type Conn interface { |
||||
ConnMetadata |
||||
|
||||
// SendRequest sends a global request, and returns the
|
||||
// reply. If wantReply is true, it returns the response status
|
||||
// and payload. See also RFC4254, section 4.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) |
||||
|
||||
// OpenChannel tries to open an channel. If the request is
|
||||
// rejected, it returns *OpenChannelError. On success it returns
|
||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
||||
// requests. The Go channel must be serviced, or the
|
||||
// connection will hang.
|
||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) |
||||
|
||||
// Close closes the underlying network connection
|
||||
Close() error |
||||
|
||||
// Wait blocks until the connection has shut down, and returns the
|
||||
// error causing the shutdown.
|
||||
Wait() error |
||||
|
||||
// TODO(hanwen): consider exposing:
|
||||
// RequestKeyChange
|
||||
// Disconnect
|
||||
} |
||||
|
||||
// DiscardRequests consumes and rejects all requests from the
|
||||
// passed-in channel.
|
||||
func DiscardRequests(in <-chan *Request) { |
||||
for req := range in { |
||||
if req.WantReply { |
||||
req.Reply(false, nil) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// A connection represents an incoming connection.
|
||||
type connection struct { |
||||
transport *handshakeTransport |
||||
sshConn |
||||
|
||||
// The connection protocol.
|
||||
*mux |
||||
} |
||||
|
||||
func (c *connection) Close() error { |
||||
return c.sshConn.conn.Close() |
||||
} |
||||
|
||||
// sshconn provides net.Conn metadata, but disallows direct reads and
|
||||
// writes.
|
||||
type sshConn struct { |
||||
conn net.Conn |
||||
|
||||
user string |
||||
sessionID []byte |
||||
clientVersion []byte |
||||
serverVersion []byte |
||||
} |
||||
|
||||
func dup(src []byte) []byte { |
||||
dst := make([]byte, len(src)) |
||||
copy(dst, src) |
||||
return dst |
||||
} |
||||
|
||||
func (c *sshConn) User() string { |
||||
return c.user |
||||
} |
||||
|
||||
func (c *sshConn) RemoteAddr() net.Addr { |
||||
return c.conn.RemoteAddr() |
||||
} |
||||
|
||||
func (c *sshConn) Close() error { |
||||
return c.conn.Close() |
||||
} |
||||
|
||||
func (c *sshConn) LocalAddr() net.Addr { |
||||
return c.conn.LocalAddr() |
||||
} |
||||
|
||||
func (c *sshConn) SessionID() []byte { |
||||
return dup(c.sessionID) |
||||
} |
||||
|
||||
func (c *sshConn) ClientVersion() []byte { |
||||
return dup(c.clientVersion) |
||||
} |
||||
|
||||
func (c *sshConn) ServerVersion() []byte { |
||||
return dup(c.serverVersion) |
||||
} |
@ -1,18 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* |
||||
Package ssh implements an SSH client and server. |
||||
|
||||
SSH is a transport security protocol, an authentication protocol and a |
||||
family of application protocols. The most typical application level |
||||
protocol is a remote shell and this is specifically implemented. However, |
||||
the multiplexed nature of SSH is exposed to users that wish to support |
||||
others. |
||||
|
||||
References: |
||||
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||
*/ |
||||
package ssh |
@ -1,211 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh_test |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"log" |
||||
"net" |
||||
"net/http" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
"github.com/gogits/gogs/modules/crypto/ssh/terminal" |
||||
) |
||||
|
||||
func ExampleNewServerConn() { |
||||
// An SSH server is represented by a ServerConfig, which holds
|
||||
// certificate details and handles authentication of ServerConns.
|
||||
config := &ssh.ServerConfig{ |
||||
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { |
||||
// Should use constant-time compare (or better, salt+hash) in
|
||||
// a production setting.
|
||||
if c.User() == "testuser" && string(pass) == "tiger" { |
||||
return nil, nil |
||||
} |
||||
return nil, fmt.Errorf("password rejected for %q", c.User()) |
||||
}, |
||||
} |
||||
|
||||
privateBytes, err := ioutil.ReadFile("id_rsa") |
||||
if err != nil { |
||||
panic("Failed to load private key") |
||||
} |
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes) |
||||
if err != nil { |
||||
panic("Failed to parse private key") |
||||
} |
||||
|
||||
config.AddHostKey(private) |
||||
|
||||
// Once a ServerConfig has been configured, connections can be
|
||||
// accepted.
|
||||
listener, err := net.Listen("tcp", "0.0.0.0:2022") |
||||
if err != nil { |
||||
panic("failed to listen for connection") |
||||
} |
||||
nConn, err := listener.Accept() |
||||
if err != nil { |
||||
panic("failed to accept incoming connection") |
||||
} |
||||
|
||||
// Before use, a handshake must be performed on the incoming
|
||||
// net.Conn.
|
||||
_, chans, reqs, err := ssh.NewServerConn(nConn, config) |
||||
if err != nil { |
||||
panic("failed to handshake") |
||||
} |
||||
// The incoming Request channel must be serviced.
|
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
// Service the incoming Channel channel.
|
||||
for newChannel := range chans { |
||||
// Channels have a type, depending on the application level
|
||||
// protocol intended. In the case of a shell, the type is
|
||||
// "session" and ServerShell may be used to present a simple
|
||||
// terminal interface.
|
||||
if newChannel.ChannelType() != "session" { |
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") |
||||
continue |
||||
} |
||||
channel, requests, err := newChannel.Accept() |
||||
if err != nil { |
||||
panic("could not accept channel.") |
||||
} |
||||
|
||||
// Sessions have out-of-band requests such as "shell",
|
||||
// "pty-req" and "env". Here we handle only the
|
||||
// "shell" request.
|
||||
go func(in <-chan *ssh.Request) { |
||||
for req := range in { |
||||
ok := false |
||||
switch req.Type { |
||||
case "shell": |
||||
ok = true |
||||
if len(req.Payload) > 0 { |
||||
// We don't accept any
|
||||
// commands, only the
|
||||
// default shell.
|
||||
ok = false |
||||
} |
||||
} |
||||
req.Reply(ok, nil) |
||||
} |
||||
}(requests) |
||||
|
||||
term := terminal.NewTerminal(channel, "> ") |
||||
|
||||
go func() { |
||||
defer channel.Close() |
||||
for { |
||||
line, err := term.ReadLine() |
||||
if err != nil { |
||||
break |
||||
} |
||||
fmt.Println(line) |
||||
} |
||||
}() |
||||
} |
||||
} |
||||
|
||||
func ExampleDial() { |
||||
// An SSH client is represented with a ClientConn. Currently only
|
||||
// the "password" authentication method is supported.
|
||||
//
|
||||
// To authenticate with the remote server you must pass at least one
|
||||
// implementation of AuthMethod via the Auth field in ClientConfig.
|
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("yourpassword"), |
||||
}, |
||||
} |
||||
client, err := ssh.Dial("tcp", "yourserver.com:22", config) |
||||
if err != nil { |
||||
panic("Failed to dial: " + err.Error()) |
||||
} |
||||
|
||||
// Each ClientConn can support multiple interactive sessions,
|
||||
// represented by a Session.
|
||||
session, err := client.NewSession() |
||||
if err != nil { |
||||
panic("Failed to create session: " + err.Error()) |
||||
} |
||||
defer session.Close() |
||||
|
||||
// Once a Session is created, you can execute a single command on
|
||||
// the remote side using the Run method.
|
||||
var b bytes.Buffer |
||||
session.Stdout = &b |
||||
if err := session.Run("/usr/bin/whoami"); err != nil { |
||||
panic("Failed to run: " + err.Error()) |
||||
} |
||||
fmt.Println(b.String()) |
||||
} |
||||
|
||||
func ExampleClient_Listen() { |
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("password"), |
||||
}, |
||||
} |
||||
// Dial your ssh server.
|
||||
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||
if err != nil { |
||||
log.Fatalf("unable to connect: %s", err) |
||||
} |
||||
defer conn.Close() |
||||
|
||||
// Request the remote side to open port 8080 on all interfaces.
|
||||
l, err := conn.Listen("tcp", "0.0.0.0:8080") |
||||
if err != nil { |
||||
log.Fatalf("unable to register tcp forward: %v", err) |
||||
} |
||||
defer l.Close() |
||||
|
||||
// Serve HTTP with your SSH server acting as a reverse proxy.
|
||||
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||
fmt.Fprintf(resp, "Hello world!\n") |
||||
})) |
||||
} |
||||
|
||||
func ExampleSession_RequestPty() { |
||||
// Create client config
|
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("password"), |
||||
}, |
||||
} |
||||
// Connect to ssh server
|
||||
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||
if err != nil { |
||||
log.Fatalf("unable to connect: %s", err) |
||||
} |
||||
defer conn.Close() |
||||
// Create a session
|
||||
session, err := conn.NewSession() |
||||
if err != nil { |
||||
log.Fatalf("unable to create session: %s", err) |
||||
} |
||||
defer session.Close() |
||||
// Set up terminal modes
|
||||
modes := ssh.TerminalModes{ |
||||
ssh.ECHO: 0, // disable echoing
|
||||
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
||||
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
||||
} |
||||
// Request pseudo terminal
|
||||
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { |
||||
log.Fatalf("request for pseudo terminal failed: %s", err) |
||||
} |
||||
// Start remote shell
|
||||
if err := session.Shell(); err != nil { |
||||
log.Fatalf("failed to start shell: %s", err) |
||||
} |
||||
} |
@ -1,412 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"sync" |
||||
) |
||||
|
||||
// debugHandshake, if set, prints messages sent and received. Key
|
||||
// exchange messages are printed as if DH were used, so the debug
|
||||
// messages are wrong when using ECDH.
|
||||
const debugHandshake = false |
||||
|
||||
// keyingTransport is a packet based transport that supports key
|
||||
// changes. It need not be thread-safe. It should pass through
|
||||
// msgNewKeys in both directions.
|
||||
type keyingTransport interface { |
||||
packetConn |
||||
|
||||
// prepareKeyChange sets up a key change. The key change for a
|
||||
// direction will be effected if a msgNewKeys message is sent
|
||||
// or received.
|
||||
prepareKeyChange(*algorithms, *kexResult) error |
||||
|
||||
// getSessionID returns the session ID. prepareKeyChange must
|
||||
// have been called once.
|
||||
getSessionID() []byte |
||||
} |
||||
|
||||
// rekeyingTransport is the interface of handshakeTransport that we
|
||||
// (internally) expose to ClientConn and ServerConn.
|
||||
type rekeyingTransport interface { |
||||
packetConn |
||||
|
||||
// requestKeyChange asks the remote side to change keys. All
|
||||
// writes are blocked until the key change succeeds, which is
|
||||
// signaled by reading a msgNewKeys.
|
||||
requestKeyChange() error |
||||
|
||||
// getSessionID returns the session ID. This is only valid
|
||||
// after the first key change has completed.
|
||||
getSessionID() []byte |
||||
} |
||||
|
||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
||||
// and offers a thread-safe writePacket() interface.
|
||||
type handshakeTransport struct { |
||||
conn keyingTransport |
||||
config *Config |
||||
|
||||
serverVersion []byte |
||||
clientVersion []byte |
||||
|
||||
// hostKeys is non-empty if we are the server. In that case,
|
||||
// it contains all host keys that can be used to sign the
|
||||
// connection.
|
||||
hostKeys []Signer |
||||
|
||||
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
||||
// we accept these key types from the server as host key.
|
||||
hostKeyAlgorithms []string |
||||
|
||||
// On read error, incoming is closed, and readError is set.
|
||||
incoming chan []byte |
||||
readError error |
||||
|
||||
// data for host key checking
|
||||
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
dialAddress string |
||||
remoteAddr net.Addr |
||||
|
||||
readSinceKex uint64 |
||||
|
||||
// Protects the writing side of the connection
|
||||
mu sync.Mutex |
||||
cond *sync.Cond |
||||
sentInitPacket []byte |
||||
sentInitMsg *kexInitMsg |
||||
writtenSinceKex uint64 |
||||
writeError error |
||||
} |
||||
|
||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { |
||||
t := &handshakeTransport{ |
||||
conn: conn, |
||||
serverVersion: serverVersion, |
||||
clientVersion: clientVersion, |
||||
incoming: make(chan []byte, 16), |
||||
config: config, |
||||
} |
||||
t.cond = sync.NewCond(&t.mu) |
||||
return t |
||||
} |
||||
|
||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.dialAddress = dialAddr |
||||
t.remoteAddr = addr |
||||
t.hostKeyCallback = config.HostKeyCallback |
||||
if config.HostKeyAlgorithms != nil { |
||||
t.hostKeyAlgorithms = config.HostKeyAlgorithms |
||||
} else { |
||||
t.hostKeyAlgorithms = supportedHostKeyAlgos |
||||
} |
||||
go t.readLoop() |
||||
return t |
||||
} |
||||
|
||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.hostKeys = config.hostKeys |
||||
go t.readLoop() |
||||
return t |
||||
} |
||||
|
||||
func (t *handshakeTransport) getSessionID() []byte { |
||||
return t.conn.getSessionID() |
||||
} |
||||
|
||||
func (t *handshakeTransport) id() string { |
||||
if len(t.hostKeys) > 0 { |
||||
return "server" |
||||
} |
||||
return "client" |
||||
} |
||||
|
||||
func (t *handshakeTransport) readPacket() ([]byte, error) { |
||||
p, ok := <-t.incoming |
||||
if !ok { |
||||
return nil, t.readError |
||||
} |
||||
return p, nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) readLoop() { |
||||
for { |
||||
p, err := t.readOnePacket() |
||||
if err != nil { |
||||
t.readError = err |
||||
close(t.incoming) |
||||
break |
||||
} |
||||
if p[0] == msgIgnore || p[0] == msgDebug { |
||||
continue |
||||
} |
||||
t.incoming <- p |
||||
} |
||||
|
||||
// If we can't read, declare the writing part dead too.
|
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
if t.writeError == nil { |
||||
t.writeError = t.readError |
||||
} |
||||
t.cond.Broadcast() |
||||
} |
||||
|
||||
func (t *handshakeTransport) readOnePacket() ([]byte, error) { |
||||
if t.readSinceKex > t.config.RekeyThreshold { |
||||
if err := t.requestKeyChange(); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
p, err := t.conn.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
t.readSinceKex += uint64(len(p)) |
||||
if debugHandshake { |
||||
msg, err := decode(p) |
||||
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) |
||||
} |
||||
if p[0] != msgKexInit { |
||||
return p, nil |
||||
} |
||||
err = t.enterKeyExchange(p) |
||||
|
||||
t.mu.Lock() |
||||
if err != nil { |
||||
// drop connection
|
||||
t.conn.Close() |
||||
t.writeError = err |
||||
} |
||||
|
||||
if debugHandshake { |
||||
log.Printf("%s exited key exchange, err %v", t.id(), err) |
||||
} |
||||
|
||||
// Unblock writers.
|
||||
t.sentInitMsg = nil |
||||
t.sentInitPacket = nil |
||||
t.cond.Broadcast() |
||||
t.writtenSinceKex = 0 |
||||
t.mu.Unlock() |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
t.readSinceKex = 0 |
||||
return []byte{msgNewKeys}, nil |
||||
} |
||||
|
||||
// sendKexInit sends a key change message, and returns the message
|
||||
// that was sent. After initiating the key change, all writes will be
|
||||
// blocked until the change is done, and a failed key change will
|
||||
// close the underlying transport. This function is safe for
|
||||
// concurrent use by multiple goroutines.
|
||||
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
return t.sendKexInitLocked() |
||||
} |
||||
|
||||
func (t *handshakeTransport) requestKeyChange() error { |
||||
_, _, err := t.sendKexInit() |
||||
return err |
||||
} |
||||
|
||||
// sendKexInitLocked sends a key change message. t.mu must be locked
|
||||
// while this happens.
|
||||
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { |
||||
// kexInits may be sent either in response to the other side,
|
||||
// or because our side wants to initiate a key change, so we
|
||||
// may have already sent a kexInit. In that case, don't send a
|
||||
// second kexInit.
|
||||
if t.sentInitMsg != nil { |
||||
return t.sentInitMsg, t.sentInitPacket, nil |
||||
} |
||||
msg := &kexInitMsg{ |
||||
KexAlgos: t.config.KeyExchanges, |
||||
CiphersClientServer: t.config.Ciphers, |
||||
CiphersServerClient: t.config.Ciphers, |
||||
MACsClientServer: t.config.MACs, |
||||
MACsServerClient: t.config.MACs, |
||||
CompressionClientServer: supportedCompressions, |
||||
CompressionServerClient: supportedCompressions, |
||||
} |
||||
io.ReadFull(rand.Reader, msg.Cookie[:]) |
||||
|
||||
if len(t.hostKeys) > 0 { |
||||
for _, k := range t.hostKeys { |
||||
msg.ServerHostKeyAlgos = append( |
||||
msg.ServerHostKeyAlgos, k.PublicKey().Type()) |
||||
} |
||||
} else { |
||||
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms |
||||
} |
||||
packet := Marshal(msg) |
||||
|
||||
// writePacket destroys the contents, so save a copy.
|
||||
packetCopy := make([]byte, len(packet)) |
||||
copy(packetCopy, packet) |
||||
|
||||
if err := t.conn.writePacket(packetCopy); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
t.sentInitMsg = msg |
||||
t.sentInitPacket = packet |
||||
return msg, packet, nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) writePacket(p []byte) error { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
|
||||
if t.writtenSinceKex > t.config.RekeyThreshold { |
||||
t.sendKexInitLocked() |
||||
} |
||||
for t.sentInitMsg != nil && t.writeError == nil { |
||||
t.cond.Wait() |
||||
} |
||||
if t.writeError != nil { |
||||
return t.writeError |
||||
} |
||||
t.writtenSinceKex += uint64(len(p)) |
||||
|
||||
switch p[0] { |
||||
case msgKexInit: |
||||
return errors.New("ssh: only handshakeTransport can send kexInit") |
||||
case msgNewKeys: |
||||
return errors.New("ssh: only handshakeTransport can send newKeys") |
||||
default: |
||||
return t.conn.writePacket(p) |
||||
} |
||||
} |
||||
|
||||
func (t *handshakeTransport) Close() error { |
||||
return t.conn.Close() |
||||
} |
||||
|
||||
// enterKeyExchange runs the key exchange.
|
||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { |
||||
if debugHandshake { |
||||
log.Printf("%s entered key exchange", t.id()) |
||||
} |
||||
myInit, myInitPacket, err := t.sendKexInit() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
otherInit := &kexInitMsg{} |
||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil { |
||||
return err |
||||
} |
||||
|
||||
magics := handshakeMagics{ |
||||
clientVersion: t.clientVersion, |
||||
serverVersion: t.serverVersion, |
||||
clientKexInit: otherInitPacket, |
||||
serverKexInit: myInitPacket, |
||||
} |
||||
|
||||
clientInit := otherInit |
||||
serverInit := myInit |
||||
if len(t.hostKeys) == 0 { |
||||
clientInit = myInit |
||||
serverInit = otherInit |
||||
|
||||
magics.clientKexInit = myInitPacket |
||||
magics.serverKexInit = otherInitPacket |
||||
} |
||||
|
||||
algs, err := findAgreedAlgorithms(clientInit, serverInit) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// We don't send FirstKexFollows, but we handle receiving it.
|
||||
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { |
||||
// other side sent a kex message for the wrong algorithm,
|
||||
// which we have to ignore.
|
||||
if _, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
kex, ok := kexAlgoMap[algs.kex] |
||||
if !ok { |
||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) |
||||
} |
||||
|
||||
var result *kexResult |
||||
if len(t.hostKeys) > 0 { |
||||
result, err = t.server(kex, algs, &magics) |
||||
} else { |
||||
result, err = t.client(kex, algs, &magics) |
||||
} |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
t.conn.prepareKeyChange(algs, result) |
||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { |
||||
return err |
||||
} |
||||
if packet, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} else if packet[0] != msgNewKeys { |
||||
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
var hostKey Signer |
||||
for _, k := range t.hostKeys { |
||||
if algs.hostKey == k.PublicKey().Type() { |
||||
hostKey = k |
||||
} |
||||
} |
||||
|
||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) |
||||
return r, err |
||||
} |
||||
|
||||
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
result, err := kex.Client(t.conn, t.config.Rand, magics) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKey, err := ParsePublicKey(result.HostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err := verifyHostKeySignature(hostKey, result); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if t.hostKeyCallback != nil { |
||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return result, nil |
||||
} |
@ -1,415 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"runtime" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
type testChecker struct { |
||||
calls []string |
||||
} |
||||
|
||||
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||
if dialAddr == "bad" { |
||||
return fmt.Errorf("dialAddr is bad") |
||||
} |
||||
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { |
||||
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) |
||||
} |
||||
|
||||
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||
// a write.)
|
||||
func netPipe() (net.Conn, net.Conn, error) { |
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
defer listener.Close() |
||||
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
c2, err := listener.Accept() |
||||
if err != nil { |
||||
c1.Close() |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c1, c2, nil |
||||
} |
||||
|
||||
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
trC := newTransport(a, rand.Reader, true) |
||||
trS := newTransport(b, rand.Reader, false) |
||||
clientConf.SetDefaults() |
||||
|
||||
v := []byte("version") |
||||
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) |
||||
|
||||
serverConf := &ServerConfig{} |
||||
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
serverConf.SetDefaults() |
||||
server = newServerTransport(trS, v, v, serverConf) |
||||
|
||||
return client, server, nil |
||||
} |
||||
|
||||
func TestHandshakeBasic(t *testing.T) { |
||||
if runtime.GOOS == "plan9" { |
||||
t.Skip("see golang.org/issue/7237") |
||||
} |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
|
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
go func() { |
||||
// Client writes a bunch of stuff, and does a key
|
||||
// change in the middle. This should not confuse the
|
||||
// handshake in progress
|
||||
for i := 0; i < 10; i++ { |
||||
p := []byte{msgRequestSuccess, byte(i)} |
||||
if err := trC.writePacket(p); err != nil { |
||||
t.Fatalf("sendPacket: %v", err) |
||||
} |
||||
if i == 5 { |
||||
// halfway through, we request a key change.
|
||||
_, _, err := trC.sendKexInit() |
||||
if err != nil { |
||||
t.Fatalf("sendKexInit: %v", err) |
||||
} |
||||
} |
||||
} |
||||
trC.Close() |
||||
}() |
||||
|
||||
// Server checks that client messages come in cleanly
|
||||
i := 0 |
||||
for { |
||||
p, err := trS.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
if p[0] == msgNewKeys { |
||||
continue |
||||
} |
||||
want := []byte{msgRequestSuccess, byte(i)} |
||||
if bytes.Compare(p, want) != 0 { |
||||
t.Errorf("message %d: got %q, want %q", i, p, want) |
||||
} |
||||
i++ |
||||
} |
||||
if i != 10 { |
||||
t.Errorf("received %d messages, want 10.", i) |
||||
} |
||||
|
||||
// If all went well, we registered exactly 1 key change.
|
||||
if len(checker.calls) != 1 { |
||||
t.Fatalf("got %d host key checks, want 1", len(checker.calls)) |
||||
} |
||||
|
||||
pub := testSigners["ecdsa"].PublicKey() |
||||
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) |
||||
if want != checker.calls[0] { |
||||
t.Errorf("got %q want %q for host key check", checker.calls[0], want) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeError(t *testing.T) { |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
// send a packet
|
||||
packet := []byte{msgRequestSuccess, 42} |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// Now request a key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
// the key change will fail, and afterwards we can't write.
|
||||
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { |
||||
t.Errorf("writePacket after botched rekey succeeded.") |
||||
} |
||||
|
||||
readback, err := trS.readPacket() |
||||
if err != nil { |
||||
t.Fatalf("server closed too soon: %v", err) |
||||
} |
||||
if bytes.Compare(readback, packet) != 0 { |
||||
t.Errorf("got %q want %q", readback, packet) |
||||
} |
||||
readback, err = trS.readPacket() |
||||
if err == nil { |
||||
t.Errorf("got a message %q after failed key change", readback) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeTwice(t *testing.T) { |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
|
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
// send a packet
|
||||
packet := make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// Now request a key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
// Send another packet. Use a fresh one, since writePacket destroys.
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// 2nd key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
for i := 0; i < 5; i++ { |
||||
msg, err := trS.readPacket() |
||||
if err != nil { |
||||
t.Fatalf("server closed too soon: %v", err) |
||||
} |
||||
if msg[0] == msgNewKeys { |
||||
continue |
||||
} |
||||
|
||||
if bytes.Compare(msg, packet) != 0 { |
||||
t.Errorf("packet %d: got %q want %q", i, msg, packet) |
||||
} |
||||
} |
||||
if len(checker.calls) != 2 { |
||||
t.Errorf("got %d key changes, want 2", len(checker.calls)) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeAutoRekeyWrite(t *testing.T) { |
||||
checker := &testChecker{} |
||||
clientConf := &ClientConfig{HostKeyCallback: checker.Check} |
||||
clientConf.RekeyThreshold = 500 |
||||
trC, trS, err := handshakePair(clientConf, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
for i := 0; i < 5; i++ { |
||||
packet := make([]byte, 251) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
} |
||||
|
||||
j := 0 |
||||
for ; j < 5; j++ { |
||||
_, err := trS.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
|
||||
if j != 5 { |
||||
t.Errorf("got %d, want 5 messages", j) |
||||
} |
||||
|
||||
if len(checker.calls) != 2 { |
||||
t.Errorf("got %d key changes, wanted 2", len(checker.calls)) |
||||
} |
||||
} |
||||
|
||||
type syncChecker struct { |
||||
called chan int |
||||
} |
||||
|
||||
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||
t.called <- 1 |
||||
return nil |
||||
} |
||||
|
||||
func TestHandshakeAutoRekeyRead(t *testing.T) { |
||||
sync := &syncChecker{make(chan int, 2)} |
||||
clientConf := &ClientConfig{ |
||||
HostKeyCallback: sync.Check, |
||||
} |
||||
clientConf.RekeyThreshold = 500 |
||||
|
||||
trC, trS, err := handshakePair(clientConf, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
packet := make([]byte, 501) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trS.writePacket(packet); err != nil { |
||||
t.Fatalf("writePacket: %v", err) |
||||
} |
||||
// While we read out the packet, a key change will be
|
||||
// initiated.
|
||||
if _, err := trC.readPacket(); err != nil { |
||||
t.Fatalf("readPacket(client): %v", err) |
||||
} |
||||
|
||||
<-sync.called |
||||
} |
||||
|
||||
// errorKeyingTransport generates errors after a given number of
|
||||
// read/write operations.
|
||||
type errorKeyingTransport struct { |
||||
packetConn |
||||
readLeft, writeLeft int |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { |
||||
return nil |
||||
} |
||||
func (n *errorKeyingTransport) getSessionID() []byte { |
||||
return nil |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) writePacket(packet []byte) error { |
||||
if n.writeLeft == 0 { |
||||
n.Close() |
||||
return errors.New("barf") |
||||
} |
||||
|
||||
n.writeLeft-- |
||||
return n.packetConn.writePacket(packet) |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) readPacket() ([]byte, error) { |
||||
if n.readLeft == 0 { |
||||
n.Close() |
||||
return nil, errors.New("barf") |
||||
} |
||||
|
||||
n.readLeft-- |
||||
return n.packetConn.readPacket() |
||||
} |
||||
|
||||
func TestHandshakeErrorHandlingRead(t *testing.T) { |
||||
for i := 0; i < 20; i++ { |
||||
testHandshakeErrorHandlingN(t, i, -1) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeErrorHandlingWrite(t *testing.T) { |
||||
for i := 0; i < 20; i++ { |
||||
testHandshakeErrorHandlingN(t, -1, i) |
||||
} |
||||
} |
||||
|
||||
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
|
||||
// handshakeTransport deadlocks, the go runtime will detect it and
|
||||
// panic.
|
||||
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { |
||||
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) |
||||
|
||||
a, b := memPipe() |
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
key := testSigners["ecdsa"] |
||||
serverConf := Config{RekeyThreshold: minRekeyThreshold} |
||||
serverConf.SetDefaults() |
||||
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) |
||||
serverConn.hostKeys = []Signer{key} |
||||
go serverConn.readLoop() |
||||
|
||||
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} |
||||
clientConf.SetDefaults() |
||||
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) |
||||
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} |
||||
go clientConn.readLoop() |
||||
|
||||
var wg sync.WaitGroup |
||||
wg.Add(4) |
||||
|
||||
for _, hs := range []packetConn{serverConn, clientConn} { |
||||
go func(c packetConn) { |
||||
for { |
||||
err := c.writePacket(msg) |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
wg.Done() |
||||
}(hs) |
||||
go func(c packetConn) { |
||||
for { |
||||
_, err := c.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
wg.Done() |
||||
}(hs) |
||||
} |
||||
|
||||
wg.Wait() |
||||
} |
@ -1,526 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/subtle" |
||||
"crypto/rand" |
||||
"errors" |
||||
"io" |
||||
"math/big" |
||||
|
||||
"golang.org/x/crypto/curve25519" |
||||
) |
||||
|
||||
const ( |
||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" |
||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" |
||||
kexAlgoECDH256 = "ecdh-sha2-nistp256" |
||||
kexAlgoECDH384 = "ecdh-sha2-nistp384" |
||||
kexAlgoECDH521 = "ecdh-sha2-nistp521" |
||||
kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" |
||||
) |
||||
|
||||
// kexResult captures the outcome of a key exchange.
|
||||
type kexResult struct { |
||||
// Session hash. See also RFC 4253, section 8.
|
||||
H []byte |
||||
|
||||
// Shared secret. See also RFC 4253, section 8.
|
||||
K []byte |
||||
|
||||
// Host key as hashed into H.
|
||||
HostKey []byte |
||||
|
||||
// Signature of H.
|
||||
Signature []byte |
||||
|
||||
// A cryptographic hash function that matches the security
|
||||
// level of the key exchange algorithm. It is used for
|
||||
// calculating H, and for deriving keys from H and K.
|
||||
Hash crypto.Hash |
||||
|
||||
// The session ID, which is the first H computed. This is used
|
||||
// to signal data inside transport.
|
||||
SessionID []byte |
||||
} |
||||
|
||||
// handshakeMagics contains data that is always included in the
|
||||
// session hash.
|
||||
type handshakeMagics struct { |
||||
clientVersion, serverVersion []byte |
||||
clientKexInit, serverKexInit []byte |
||||
} |
||||
|
||||
func (m *handshakeMagics) write(w io.Writer) { |
||||
writeString(w, m.clientVersion) |
||||
writeString(w, m.serverVersion) |
||||
writeString(w, m.clientKexInit) |
||||
writeString(w, m.serverKexInit) |
||||
} |
||||
|
||||
// kexAlgorithm abstracts different key exchange algorithms.
|
||||
type kexAlgorithm interface { |
||||
// Server runs server-side key agreement, signing the result
|
||||
// with a hostkey.
|
||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) |
||||
|
||||
// Client runs the client-side key agreement. Caller is
|
||||
// responsible for verifying the host key signature.
|
||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) |
||||
} |
||||
|
||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||
type dhGroup struct { |
||||
g, p *big.Int |
||||
} |
||||
|
||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { |
||||
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 { |
||||
return nil, errors.New("ssh: DH parameter out of bounds") |
||||
} |
||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil |
||||
} |
||||
|
||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
hashFunc := crypto.SHA1 |
||||
|
||||
x, err := rand.Int(randSource, group.p) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
X := new(big.Int).Exp(group.g, x, group.p) |
||||
kexDHInit := kexDHInitMsg{ |
||||
X: X, |
||||
} |
||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexDHReply kexDHReplyMsg |
||||
if err = Unmarshal(packet, &kexDHReply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kInt, err := group.diffieHellman(kexDHReply.Y, x) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, kexDHReply.HostKey) |
||||
writeInt(h, X) |
||||
writeInt(h, kexDHReply.Y) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: kexDHReply.HostKey, |
||||
Signature: kexDHReply.Signature, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
hashFunc := crypto.SHA1 |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexDHInit kexDHInitMsg |
||||
if err = Unmarshal(packet, &kexDHInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
y, err := rand.Int(randSource, group.p) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
Y := new(big.Int).Exp(group.g, y, group.p) |
||||
kInt, err := group.diffieHellman(kexDHInit.X, y) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeInt(h, kexDHInit.X) |
||||
writeInt(h, Y) |
||||
|
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, randSource, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexDHReply := kexDHReplyMsg{ |
||||
HostKey: hostKeyBytes, |
||||
Y: Y, |
||||
Signature: sig, |
||||
} |
||||
packet = Marshal(&kexDHReply) |
||||
|
||||
err = c.writePacket(packet) |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
||||
// described in RFC 5656, section 4.
|
||||
type ecdh struct { |
||||
curve elliptic.Curve |
||||
} |
||||
|
||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexInit := kexECDHInitMsg{ |
||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), |
||||
} |
||||
|
||||
serialized := Marshal(&kexInit) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
// unmarshalECKey parses and checks an EC key.
|
||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { |
||||
x, y = elliptic.Unmarshal(curve, pubkey) |
||||
if x == nil { |
||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") |
||||
} |
||||
if !validateECPublicKey(curve, x, y) { |
||||
return nil, nil, errors.New("ssh: public key not on curve") |
||||
} |
||||
return x, y, nil |
||||
} |
||||
|
||||
// validateECPublicKey checks that the point is a valid public key for
|
||||
// the given curve. See [SEC1], 3.2.2
|
||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { |
||||
if x.Sign() == 0 && y.Sign() == 0 { |
||||
return false |
||||
} |
||||
|
||||
if x.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if y.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if !curve.IsOnCurve(x, y) { |
||||
return false |
||||
} |
||||
|
||||
// We don't check if N * PubKey == 0, since
|
||||
//
|
||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
||||
// (We don't foresee an implementation that supports non NIST
|
||||
// curves)
|
||||
//
|
||||
// - for ephemeral keys, we don't need to worry about small
|
||||
// subgroup attacks.
|
||||
return true |
||||
} |
||||
|
||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexECDHInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexECDHInit); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// We could cache this key across multiple users/multiple
|
||||
// connection attempts, but the benefit is small. OpenSSH
|
||||
// generates a new key for each incoming connection.
|
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexECDHInit.ClientPubKey) |
||||
writeString(h, serializedEphKey) |
||||
|
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: serializedEphKey, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
|
||||
serialized := Marshal(&reply) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: sig, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
var kexAlgoMap = map[string]kexAlgorithm{} |
||||
|
||||
func init() { |
||||
// This is the group called diffie-hellman-group1-sha1 in RFC
|
||||
// 4253 and Oakley Group 2 in RFC 2409.
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) |
||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
} |
||||
|
||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
||||
// 4253 and Oakley Group 14 in RFC 3526.
|
||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) |
||||
|
||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
} |
||||
|
||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} |
||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} |
||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} |
||||
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} |
||||
} |
||||
|
||||
// curve25519sha256 implements the curve25519-sha256@libssh.org key
|
||||
// agreement protocol, as described in
|
||||
// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
|
||||
type curve25519sha256 struct{} |
||||
|
||||
type curve25519KeyPair struct { |
||||
priv [32]byte |
||||
pub [32]byte |
||||
} |
||||
|
||||
func (kp *curve25519KeyPair) generate(rand io.Reader) error { |
||||
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { |
||||
return err |
||||
} |
||||
curve25519.ScalarBaseMult(&kp.pub, &kp.priv) |
||||
return nil |
||||
} |
||||
|
||||
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
||||
// convenient to compare against in order to reject curve25519 points with the
|
||||
// wrong order.
|
||||
var curve25519Zeros [32]byte |
||||
|
||||
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
if len(reply.EphemeralPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var servPub, secret [32]byte |
||||
copy(servPub[:], reply.EphemeralPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &servPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kp.pub[:]) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
||||
|
||||
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
if len(kexInit.ClientPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var clientPub, secret [32]byte |
||||
copy(clientPub[:], kexInit.ClientPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &clientPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, kp.pub[:]) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: kp.pub[:], |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
if err := c.writePacket(Marshal(&reply)); err != nil { |
||||
return nil, err |
||||
} |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
@ -1,50 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Key exchange tests.
|
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"reflect" |
||||
"testing" |
||||
) |
||||
|
||||
func TestKexes(t *testing.T) { |
||||
type kexResultErr struct { |
||||
result *kexResult |
||||
err error |
||||
} |
||||
|
||||
for name, kex := range kexAlgoMap { |
||||
a, b := memPipe() |
||||
|
||||
s := make(chan kexResultErr, 1) |
||||
c := make(chan kexResultErr, 1) |
||||
var magics handshakeMagics |
||||
go func() { |
||||
r, e := kex.Client(a, rand.Reader, &magics) |
||||
a.Close() |
||||
c <- kexResultErr{r, e} |
||||
}() |
||||
go func() { |
||||
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) |
||||
b.Close() |
||||
s <- kexResultErr{r, e} |
||||
}() |
||||
|
||||
clientRes := <-c |
||||
serverRes := <-s |
||||
if clientRes.err != nil { |
||||
t.Errorf("client: %v", clientRes.err) |
||||
} |
||||
if serverRes.err != nil { |
||||
t.Errorf("server: %v", serverRes.err) |
||||
} |
||||
if !reflect.DeepEqual(clientRes.result, serverRes.result) { |
||||
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) |
||||
} |
||||
} |
||||
} |
@ -1,628 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rsa" |
||||
"crypto/x509" |
||||
"encoding/asn1" |
||||
"encoding/base64" |
||||
"encoding/pem" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
) |
||||
|
||||
// These constants represent the algorithm names for key types supported by this
|
||||
// package.
|
||||
const ( |
||||
KeyAlgoRSA = "ssh-rsa" |
||||
KeyAlgoDSA = "ssh-dss" |
||||
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" |
||||
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" |
||||
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" |
||||
) |
||||
|
||||
// parsePubKey parses a public key of the given algorithm.
|
||||
// Use ParsePublicKey for keys with prepended algorithm.
|
||||
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { |
||||
switch algo { |
||||
case KeyAlgoRSA: |
||||
return parseRSA(in) |
||||
case KeyAlgoDSA: |
||||
return parseDSA(in) |
||||
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: |
||||
return parseECDSA(in) |
||||
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||
cert, err := parseCert(in, certToPrivAlgo(algo)) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
return cert, nil, nil |
||||
} |
||||
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) |
||||
} |
||||
|
||||
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
|
||||
// (see sshd(8) manual page) once the options and key type fields have been
|
||||
// removed.
|
||||
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { |
||||
in = bytes.TrimSpace(in) |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
i = len(in) |
||||
} |
||||
base64Key := in[:i] |
||||
|
||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) |
||||
n, err := base64.StdEncoding.Decode(key, base64Key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
key = key[:n] |
||||
out, err = ParsePublicKey(key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
comment = string(bytes.TrimSpace(in[i:])) |
||||
return out, comment, nil |
||||
} |
||||
|
||||
// ParseAuthorizedKeys parses a public key from an authorized_keys
|
||||
// file used in OpenSSH according to the sshd(8) manual page.
|
||||
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { |
||||
for len(in) > 0 { |
||||
end := bytes.IndexByte(in, '\n') |
||||
if end != -1 { |
||||
rest = in[end+1:] |
||||
in = in[:end] |
||||
} else { |
||||
rest = nil |
||||
} |
||||
|
||||
end = bytes.IndexByte(in, '\r') |
||||
if end != -1 { |
||||
in = in[:end] |
||||
} |
||||
|
||||
in = bytes.TrimSpace(in) |
||||
if len(in) == 0 || in[0] == '#' { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
// No key type recognised. Maybe there's an options field at
|
||||
// the beginning.
|
||||
var b byte |
||||
inQuote := false |
||||
var candidateOptions []string |
||||
optionStart := 0 |
||||
for i, b = range in { |
||||
isEnd := !inQuote && (b == ' ' || b == '\t') |
||||
if (b == ',' && !inQuote) || isEnd { |
||||
if i-optionStart > 0 { |
||||
candidateOptions = append(candidateOptions, string(in[optionStart:i])) |
||||
} |
||||
optionStart = i + 1 |
||||
} |
||||
if isEnd { |
||||
break |
||||
} |
||||
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { |
||||
inQuote = !inQuote |
||||
} |
||||
} |
||||
for i < len(in) && (in[i] == ' ' || in[i] == '\t') { |
||||
i++ |
||||
} |
||||
if i == len(in) { |
||||
// Invalid line: unmatched quote
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
in = in[i:] |
||||
i = bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
options = candidateOptions |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
return nil, "", nil, nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
// ParsePublicKey parses an SSH public key formatted for use in
|
||||
// the SSH wire protocol according to RFC 4253, section 6.6.
|
||||
func ParsePublicKey(in []byte) (out PublicKey, err error) { |
||||
algo, in, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
var rest []byte |
||||
out, rest, err = parsePubKey(in, string(algo)) |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: trailing junk in public key") |
||||
} |
||||
|
||||
return out, err |
||||
} |
||||
|
||||
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
|
||||
// authorized_keys file. The return value ends with newline.
|
||||
func MarshalAuthorizedKey(key PublicKey) []byte { |
||||
b := &bytes.Buffer{} |
||||
b.WriteString(key.Type()) |
||||
b.WriteByte(' ') |
||||
e := base64.NewEncoder(base64.StdEncoding, b) |
||||
e.Write(key.Marshal()) |
||||
e.Close() |
||||
b.WriteByte('\n') |
||||
return b.Bytes() |
||||
} |
||||
|
||||
// PublicKey is an abstraction of different types of public keys.
|
||||
type PublicKey interface { |
||||
// Type returns the key's type, e.g. "ssh-rsa".
|
||||
Type() string |
||||
|
||||
// Marshal returns the serialized key data in SSH wire format,
|
||||
// with the name prefix.
|
||||
Marshal() []byte |
||||
|
||||
// Verify that sig is a signature on the given data using this
|
||||
// key. This function will hash the data appropriately first.
|
||||
Verify(data []byte, sig *Signature) error |
||||
} |
||||
|
||||
// A Signer can create signatures that verify against a public key.
|
||||
type Signer interface { |
||||
// PublicKey returns an associated PublicKey instance.
|
||||
PublicKey() PublicKey |
||||
|
||||
// Sign returns raw signature for the given data. This method
|
||||
// will apply the hash specified for the keytype to the data.
|
||||
Sign(rand io.Reader, data []byte) (*Signature, error) |
||||
} |
||||
|
||||
type rsaPublicKey rsa.PublicKey |
||||
|
||||
func (r *rsaPublicKey) Type() string { |
||||
return "ssh-rsa" |
||||
} |
||||
|
||||
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
|
||||
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
E *big.Int |
||||
N *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
if w.E.BitLen() > 24 { |
||||
return nil, nil, errors.New("ssh: exponent too large") |
||||
} |
||||
e := w.E.Int64() |
||||
if e < 3 || e&1 == 0 { |
||||
return nil, nil, errors.New("ssh: incorrect exponent") |
||||
} |
||||
|
||||
var key rsa.PublicKey |
||||
key.E = int(e) |
||||
key.N = w.N |
||||
return (*rsaPublicKey)(&key), w.Rest, nil |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Marshal() []byte { |
||||
e := new(big.Int).SetInt64(int64(r.E)) |
||||
wirekey := struct { |
||||
Name string |
||||
E *big.Int |
||||
N *big.Int |
||||
}{ |
||||
KeyAlgoRSA, |
||||
e, |
||||
r.N, |
||||
} |
||||
return Marshal(&wirekey) |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != r.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) |
||||
} |
||||
|
||||
type rsaPrivateKey struct { |
||||
*rsa.PrivateKey |
||||
} |
||||
|
||||
func (r *rsaPrivateKey) PublicKey() PublicKey { |
||||
return (*rsaPublicKey)(&r.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &Signature{ |
||||
Format: r.PublicKey().Type(), |
||||
Blob: blob, |
||||
}, nil |
||||
} |
||||
|
||||
type dsaPublicKey dsa.PublicKey |
||||
|
||||
func (r *dsaPublicKey) Type() string { |
||||
return "ssh-dss" |
||||
} |
||||
|
||||
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
|
||||
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
P, Q, G, Y *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := &dsaPublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: w.P, |
||||
Q: w.Q, |
||||
G: w.G, |
||||
}, |
||||
Y: w.Y, |
||||
} |
||||
return key, w.Rest, nil |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Marshal() []byte { |
||||
w := struct { |
||||
Name string |
||||
P, Q, G, Y *big.Int |
||||
}{ |
||||
k.Type(), |
||||
k.P, |
||||
k.Q, |
||||
k.G, |
||||
k.Y, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != k.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 4253, section 6.6,
|
||||
// The value for 'dss_signature_blob' is encoded as a string containing
|
||||
// r, followed by s (which are 160-bit integers, without lengths or
|
||||
// padding, unsigned, and in network byte order).
|
||||
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
|
||||
if len(sig.Blob) != 40 { |
||||
return errors.New("ssh: DSA signature parse error") |
||||
} |
||||
r := new(big.Int).SetBytes(sig.Blob[:20]) |
||||
s := new(big.Int).SetBytes(sig.Blob[20:]) |
||||
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
type dsaPrivateKey struct { |
||||
*dsa.PrivateKey |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) PublicKey() PublicKey { |
||||
return (*dsaPublicKey)(&k.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
r, s, err := dsa.Sign(rand, k.PrivateKey, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sig := make([]byte, 40) |
||||
rb := r.Bytes() |
||||
sb := s.Bytes() |
||||
|
||||
copy(sig[20-len(rb):20], rb) |
||||
copy(sig[40-len(sb):], sb) |
||||
|
||||
return &Signature{ |
||||
Format: k.PublicKey().Type(), |
||||
Blob: sig, |
||||
}, nil |
||||
} |
||||
|
||||
type ecdsaPublicKey ecdsa.PublicKey |
||||
|
||||
func (key *ecdsaPublicKey) Type() string { |
||||
return "ecdsa-sha2-" + key.nistID() |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) nistID() string { |
||||
switch key.Params().BitSize { |
||||
case 256: |
||||
return "nistp256" |
||||
case 384: |
||||
return "nistp384" |
||||
case 521: |
||||
return "nistp521" |
||||
} |
||||
panic("ssh: unsupported ecdsa key size") |
||||
} |
||||
|
||||
func supportedEllipticCurve(curve elliptic.Curve) bool { |
||||
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() |
||||
} |
||||
|
||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
||||
// 5656, section 6.2.1
|
||||
func ecHash(curve elliptic.Curve) crypto.Hash { |
||||
bitSize := curve.Params().BitSize |
||||
switch { |
||||
case bitSize <= 256: |
||||
return crypto.SHA256 |
||||
case bitSize <= 384: |
||||
return crypto.SHA384 |
||||
} |
||||
return crypto.SHA512 |
||||
} |
||||
|
||||
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
|
||||
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
Curve string |
||||
KeyBytes []byte |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := new(ecdsa.PublicKey) |
||||
|
||||
switch w.Curve { |
||||
case "nistp256": |
||||
key.Curve = elliptic.P256() |
||||
case "nistp384": |
||||
key.Curve = elliptic.P384() |
||||
case "nistp521": |
||||
key.Curve = elliptic.P521() |
||||
default: |
||||
return nil, nil, errors.New("ssh: unsupported curve") |
||||
} |
||||
|
||||
key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) |
||||
if key.X == nil || key.Y == nil { |
||||
return nil, nil, errors.New("ssh: invalid curve point") |
||||
} |
||||
return (*ecdsaPublicKey)(key), w.Rest, nil |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Marshal() []byte { |
||||
// See RFC 5656, section 3.1.
|
||||
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) |
||||
w := struct { |
||||
Name string |
||||
ID string |
||||
Key []byte |
||||
}{ |
||||
key.Type(), |
||||
key.nistID(), |
||||
keyBytes, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != key.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) |
||||
} |
||||
|
||||
h := ecHash(key.Curve).New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 5656, section 3.1.2,
|
||||
// The ecdsa_signature_blob value has the following specific encoding:
|
||||
// mpint r
|
||||
// mpint s
|
||||
var ecSig struct { |
||||
R *big.Int |
||||
S *big.Int |
||||
} |
||||
|
||||
if err := Unmarshal(sig.Blob, &ecSig); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
type ecdsaPrivateKey struct { |
||||
*ecdsa.PrivateKey |
||||
} |
||||
|
||||
func (k *ecdsaPrivateKey) PublicKey() PublicKey { |
||||
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := ecHash(k.PrivateKey.PublicKey.Curve).New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sig := make([]byte, intLength(r)+intLength(s)) |
||||
rest := marshalInt(sig, r) |
||||
marshalInt(rest, s) |
||||
return &Signature{ |
||||
Format: k.PublicKey().Type(), |
||||
Blob: sig, |
||||
}, nil |
||||
} |
||||
|
||||
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
|
||||
// returns a corresponding Signer instance. EC keys should use P256,
|
||||
// P384 or P521.
|
||||
func NewSignerFromKey(k interface{}) (Signer, error) { |
||||
var sshKey Signer |
||||
switch t := k.(type) { |
||||
case *rsa.PrivateKey: |
||||
sshKey = &rsaPrivateKey{t} |
||||
case *dsa.PrivateKey: |
||||
sshKey = &dsaPrivateKey{t} |
||||
case *ecdsa.PrivateKey: |
||||
if !supportedEllipticCurve(t.Curve) { |
||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||
} |
||||
|
||||
sshKey = &ecdsaPrivateKey{t} |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||
} |
||||
return sshKey, nil |
||||
} |
||||
|
||||
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
|
||||
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
|
||||
func NewPublicKey(k interface{}) (PublicKey, error) { |
||||
var sshKey PublicKey |
||||
switch t := k.(type) { |
||||
case *rsa.PublicKey: |
||||
sshKey = (*rsaPublicKey)(t) |
||||
case *ecdsa.PublicKey: |
||||
if !supportedEllipticCurve(t.Curve) { |
||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||
} |
||||
sshKey = (*ecdsaPublicKey)(t) |
||||
case *dsa.PublicKey: |
||||
sshKey = (*dsaPublicKey)(t) |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||
} |
||||
return sshKey, nil |
||||
} |
||||
|
||||
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
|
||||
// the same keys as ParseRawPrivateKey.
|
||||
func ParsePrivateKey(pemBytes []byte) (Signer, error) { |
||||
key, err := ParseRawPrivateKey(pemBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return NewSignerFromKey(key) |
||||
} |
||||
|
||||
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
|
||||
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
|
||||
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { |
||||
block, _ := pem.Decode(pemBytes) |
||||
if block == nil { |
||||
return nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
switch block.Type { |
||||
case "RSA PRIVATE KEY": |
||||
return x509.ParsePKCS1PrivateKey(block.Bytes) |
||||
case "EC PRIVATE KEY": |
||||
return x509.ParseECPrivateKey(block.Bytes) |
||||
case "DSA PRIVATE KEY": |
||||
return ParseDSAPrivateKey(block.Bytes) |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) |
||||
} |
||||
} |
||||
|
||||
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
|
||||
// specified by the OpenSSL DSA man page.
|
||||
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { |
||||
var k struct { |
||||
Version int |
||||
P *big.Int |
||||
Q *big.Int |
||||
G *big.Int |
||||
Priv *big.Int |
||||
Pub *big.Int |
||||
} |
||||
rest, err := asn1.Unmarshal(der, &k) |
||||
if err != nil { |
||||
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) |
||||
} |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: garbage after DSA key") |
||||
} |
||||
|
||||
return &dsa.PrivateKey{ |
||||
PublicKey: dsa.PublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: k.P, |
||||
Q: k.Q, |
||||
G: k.G, |
||||
}, |
||||
Y: k.Priv, |
||||
}, |
||||
X: k.Pub, |
||||
}, nil |
||||
} |
@ -1,306 +0,0 @@ |
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"encoding/base64" |
||||
"fmt" |
||||
"reflect" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh/testdata" |
||||
) |
||||
|
||||
func rawKey(pub PublicKey) interface{} { |
||||
switch k := pub.(type) { |
||||
case *rsaPublicKey: |
||||
return (*rsa.PublicKey)(k) |
||||
case *dsaPublicKey: |
||||
return (*dsa.PublicKey)(k) |
||||
case *ecdsaPublicKey: |
||||
return (*ecdsa.PublicKey)(k) |
||||
case *Certificate: |
||||
return k |
||||
} |
||||
panic("unknown key type") |
||||
} |
||||
|
||||
func TestKeyMarshalParse(t *testing.T) { |
||||
for _, priv := range testSigners { |
||||
pub := priv.PublicKey() |
||||
roundtrip, err := ParsePublicKey(pub.Marshal()) |
||||
if err != nil { |
||||
t.Errorf("ParsePublicKey(%T): %v", pub, err) |
||||
} |
||||
|
||||
k1 := rawKey(pub) |
||||
k2 := rawKey(roundtrip) |
||||
|
||||
if !reflect.DeepEqual(k1, k2) { |
||||
t.Errorf("got %#v in roundtrip, want %#v", k2, k1) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestUnsupportedCurves(t *testing.T) { |
||||
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) |
||||
if err != nil { |
||||
t.Fatalf("GenerateKey: %v", err) |
||||
} |
||||
|
||||
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err) |
||||
} |
||||
|
||||
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestNewPublicKey(t *testing.T) { |
||||
for _, k := range testSigners { |
||||
raw := rawKey(k.PublicKey()) |
||||
// Skip certificates, as NewPublicKey does not support them.
|
||||
if _, ok := raw.(*Certificate); ok { |
||||
continue |
||||
} |
||||
pub, err := NewPublicKey(raw) |
||||
if err != nil { |
||||
t.Errorf("NewPublicKey(%#v): %v", raw, err) |
||||
} |
||||
if !reflect.DeepEqual(k.PublicKey(), pub) { |
||||
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestKeySignVerify(t *testing.T) { |
||||
for _, priv := range testSigners { |
||||
pub := priv.PublicKey() |
||||
|
||||
data := []byte("sign me") |
||||
sig, err := priv.Sign(rand.Reader, data) |
||||
if err != nil { |
||||
t.Fatalf("Sign(%T): %v", priv, err) |
||||
} |
||||
|
||||
if err := pub.Verify(data, sig); err != nil { |
||||
t.Errorf("publicKey.Verify(%T): %v", priv, err) |
||||
} |
||||
sig.Blob[5]++ |
||||
if err := pub.Verify(data, sig); err == nil { |
||||
t.Errorf("publicKey.Verify on broken sig did not fail") |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestParseRSAPrivateKey(t *testing.T) { |
||||
key := testPrivateKeys["rsa"] |
||||
|
||||
rsa, ok := key.(*rsa.PrivateKey) |
||||
if !ok { |
||||
t.Fatalf("got %T, want *rsa.PrivateKey", rsa) |
||||
} |
||||
|
||||
if err := rsa.Validate(); err != nil { |
||||
t.Errorf("Validate: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestParseECPrivateKey(t *testing.T) { |
||||
key := testPrivateKeys["ecdsa"] |
||||
|
||||
ecKey, ok := key.(*ecdsa.PrivateKey) |
||||
if !ok { |
||||
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) |
||||
} |
||||
|
||||
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { |
||||
t.Fatalf("public key does not validate.") |
||||
} |
||||
} |
||||
|
||||
func TestParseDSA(t *testing.T) { |
||||
// We actually exercise the ParsePrivateKey codepath here, as opposed to
|
||||
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
|
||||
// uses.
|
||||
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) |
||||
if err != nil { |
||||
t.Fatalf("ParsePrivateKey returned error: %s", err) |
||||
} |
||||
|
||||
data := []byte("sign me") |
||||
sig, err := s.Sign(rand.Reader, data) |
||||
if err != nil { |
||||
t.Fatalf("dsa.Sign: %v", err) |
||||
} |
||||
|
||||
if err := s.PublicKey().Verify(data, sig); err != nil { |
||||
t.Errorf("Verify failed: %v", err) |
||||
} |
||||
} |
||||
|
||||
// Tests for authorized_keys parsing.
|
||||
|
||||
// getTestKey returns a public key, and its base64 encoding.
|
||||
func getTestKey() (PublicKey, string) { |
||||
k := testPublicKeys["rsa"] |
||||
|
||||
b := &bytes.Buffer{} |
||||
e := base64.NewEncoder(base64.StdEncoding, b) |
||||
e.Write(k.Marshal()) |
||||
e.Close() |
||||
|
||||
return k, b.String() |
||||
} |
||||
|
||||
func TestMarshalParsePublicKey(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) |
||||
|
||||
authKeys := MarshalAuthorizedKey(pub) |
||||
actualFields := strings.Fields(string(authKeys)) |
||||
if len(actualFields) == 0 { |
||||
t.Fatalf("failed authKeys: %v", authKeys) |
||||
} |
||||
|
||||
// drop the comment
|
||||
expectedFields := strings.Fields(line)[0:2] |
||||
|
||||
if !reflect.DeepEqual(actualFields, expectedFields) { |
||||
t.Errorf("got %v, expected %v", actualFields, expectedFields) |
||||
} |
||||
|
||||
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) |
||||
if err != nil { |
||||
t.Fatalf("cannot parse %v: %v", line, err) |
||||
} |
||||
if !reflect.DeepEqual(actPub, pub) { |
||||
t.Errorf("got %v, expected %v", actPub, pub) |
||||
} |
||||
} |
||||
|
||||
type authResult struct { |
||||
pubKey PublicKey |
||||
options []string |
||||
comments string |
||||
rest string |
||||
ok bool |
||||
} |
||||
|
||||
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { |
||||
rest := authKeys |
||||
var values []authResult |
||||
for len(rest) > 0 { |
||||
var r authResult |
||||
var err error |
||||
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) |
||||
r.ok = (err == nil) |
||||
t.Log(err) |
||||
r.rest = string(rest) |
||||
values = append(values, r) |
||||
} |
||||
|
||||
if !reflect.DeepEqual(values, expected) { |
||||
t.Errorf("got %#v, expected %#v", values, expected) |
||||
} |
||||
} |
||||
|
||||
func TestAuthorizedKeyBasic(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
line := "ssh-rsa " + pubSerialized + " user@host" |
||||
testAuthorizedKeys(t, []byte(line), |
||||
[]authResult{ |
||||
{pub, nil, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuth(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithOptions := []string{ |
||||
`# comments to ignore before any keys...`, |
||||
``, |
||||
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, |
||||
`# comments to ignore, along with a blank line`, |
||||
``, |
||||
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, |
||||
``, |
||||
`# more comments, plus a invalid entry`, |
||||
`ssh-rsa data-that-will-not-parse user@host3`, |
||||
} |
||||
for _, eol := range []string{"\n", "\r\n"} { |
||||
authOptions := strings.Join(authWithOptions, eol) |
||||
rest2 := strings.Join(authWithOptions[3:], eol) |
||||
rest3 := strings.Join(authWithOptions[6:], eol) |
||||
testAuthorizedKeys(t, []byte(authOptions), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, |
||||
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, |
||||
{nil, nil, "", "", false}, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestAuthWithQuotedSpaceInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithQuotedCommaInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithQuotedQuoteInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) |
||||
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
|
||||
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ |
||||
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithInvalidSpace(t *testing.T) { |
||||
_, pubSerialized := getTestKey() |
||||
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||
#more to follow but still no valid keys`) |
||||
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ |
||||
{nil, nil, "", "", false}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithMissingQuote(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) |
||||
|
||||
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestInvalidEntry(t *testing.T) { |
||||
authInvalid := []byte(`ssh-rsa`) |
||||
_, _, _, _, err := ParseAuthorizedKey(authInvalid) |
||||
if err == nil { |
||||
t.Errorf("got valid entry for %q", authInvalid) |
||||
} |
||||
} |
@ -1,57 +0,0 @@ |
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Message authentication support
|
||||
|
||||
import ( |
||||
"crypto/hmac" |
||||
"crypto/sha1" |
||||
"crypto/sha256" |
||||
"hash" |
||||
) |
||||
|
||||
type macMode struct { |
||||
keySize int |
||||
new func(key []byte) hash.Hash |
||||
} |
||||
|
||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||
// a given size.
|
||||
type truncatingMAC struct { |
||||
length int |
||||
hmac hash.Hash |
||||
} |
||||
|
||||
func (t truncatingMAC) Write(data []byte) (int, error) { |
||||
return t.hmac.Write(data) |
||||
} |
||||
|
||||
func (t truncatingMAC) Sum(in []byte) []byte { |
||||
out := t.hmac.Sum(in) |
||||
return out[:len(in)+t.length] |
||||
} |
||||
|
||||
func (t truncatingMAC) Reset() { |
||||
t.hmac.Reset() |
||||
} |
||||
|
||||
func (t truncatingMAC) Size() int { |
||||
return t.length |
||||
} |
||||
|
||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } |
||||
|
||||
var macModes = map[string]*macMode{ |
||||
"hmac-sha2-256": {32, func(key []byte) hash.Hash { |
||||
return hmac.New(sha256.New, key) |
||||
}}, |
||||
"hmac-sha1": {20, func(key []byte) hash.Hash { |
||||
return hmac.New(sha1.New, key) |
||||
}}, |
||||
"hmac-sha1-96": {20, func(key []byte) hash.Hash { |
||||
return truncatingMAC{12, hmac.New(sha1.New, key)} |
||||
}}, |
||||
} |
@ -1,110 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
// An in-memory packetConn. It is safe to call Close and writePacket
|
||||
// from different goroutines.
|
||||
type memTransport struct { |
||||
eof bool |
||||
pending [][]byte |
||||
write *memTransport |
||||
sync.Mutex |
||||
*sync.Cond |
||||
} |
||||
|
||||
func (t *memTransport) readPacket() ([]byte, error) { |
||||
t.Lock() |
||||
defer t.Unlock() |
||||
for { |
||||
if len(t.pending) > 0 { |
||||
r := t.pending[0] |
||||
t.pending = t.pending[1:] |
||||
return r, nil |
||||
} |
||||
if t.eof { |
||||
return nil, io.EOF |
||||
} |
||||
t.Cond.Wait() |
||||
} |
||||
} |
||||
|
||||
func (t *memTransport) closeSelf() error { |
||||
t.Lock() |
||||
defer t.Unlock() |
||||
if t.eof { |
||||
return io.EOF |
||||
} |
||||
t.eof = true |
||||
t.Cond.Broadcast() |
||||
return nil |
||||
} |
||||
|
||||
func (t *memTransport) Close() error { |
||||
err := t.write.closeSelf() |
||||
t.closeSelf() |
||||
return err |
||||
} |
||||
|
||||
func (t *memTransport) writePacket(p []byte) error { |
||||
t.write.Lock() |
||||
defer t.write.Unlock() |
||||
if t.write.eof { |
||||
return io.EOF |
||||
} |
||||
c := make([]byte, len(p)) |
||||
copy(c, p) |
||||
t.write.pending = append(t.write.pending, c) |
||||
t.write.Cond.Signal() |
||||
return nil |
||||
} |
||||
|
||||
func memPipe() (a, b packetConn) { |
||||
t1 := memTransport{} |
||||
t2 := memTransport{} |
||||
t1.write = &t2 |
||||
t2.write = &t1 |
||||
t1.Cond = sync.NewCond(&t1.Mutex) |
||||
t2.Cond = sync.NewCond(&t2.Mutex) |
||||
return &t1, &t2 |
||||
} |
||||
|
||||
func TestMemPipe(t *testing.T) { |
||||
a, b := memPipe() |
||||
if err := a.writePacket([]byte{42}); err != nil { |
||||
t.Fatalf("writePacket: %v", err) |
||||
} |
||||
if err := a.Close(); err != nil { |
||||
t.Fatal("Close: ", err) |
||||
} |
||||
p, err := b.readPacket() |
||||
if err != nil { |
||||
t.Fatal("readPacket: ", err) |
||||
} |
||||
if len(p) != 1 || p[0] != 42 { |
||||
t.Fatalf("got %v, want {42}", p) |
||||
} |
||||
p, err = b.readPacket() |
||||
if err != io.EOF { |
||||
t.Fatalf("got %v, %v, want EOF", p, err) |
||||
} |
||||
} |
||||
|
||||
func TestDoubleClose(t *testing.T) { |
||||
a, _ := memPipe() |
||||
err := a.Close() |
||||
if err != nil { |
||||
t.Errorf("Close: %v", err) |
||||
} |
||||
err = a.Close() |
||||
if err != io.EOF { |
||||
t.Errorf("expect EOF on double close.") |
||||
} |
||||
} |
@ -1,725 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
"reflect" |
||||
"strconv" |
||||
) |
||||
|
||||
// These are SSH message type numbers. They are scattered around several
|
||||
// documents but many were taken from [SSH-PARAMETERS].
|
||||
const ( |
||||
msgIgnore = 2 |
||||
msgUnimplemented = 3 |
||||
msgDebug = 4 |
||||
msgNewKeys = 21 |
||||
|
||||
// Standard authentication messages
|
||||
msgUserAuthSuccess = 52 |
||||
msgUserAuthBanner = 53 |
||||
) |
||||
|
||||
// SSH messages:
|
||||
//
|
||||
// These structures mirror the wire format of the corresponding SSH messages.
|
||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
||||
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
||||
|
||||
// See RFC 4253, section 11.1.
|
||||
const msgDisconnect = 1 |
||||
|
||||
// disconnectMsg is the message that signals a disconnect. It is also
|
||||
// the error type returned from mux.Wait()
|
||||
type disconnectMsg struct { |
||||
Reason uint32 `sshtype:"1"` |
||||
Message string |
||||
Language string |
||||
} |
||||
|
||||
func (d *disconnectMsg) Error() string { |
||||
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) |
||||
} |
||||
|
||||
// See RFC 4253, section 7.1.
|
||||
const msgKexInit = 20 |
||||
|
||||
type kexInitMsg struct { |
||||
Cookie [16]byte `sshtype:"20"` |
||||
KexAlgos []string |
||||
ServerHostKeyAlgos []string |
||||
CiphersClientServer []string |
||||
CiphersServerClient []string |
||||
MACsClientServer []string |
||||
MACsServerClient []string |
||||
CompressionClientServer []string |
||||
CompressionServerClient []string |
||||
LanguagesClientServer []string |
||||
LanguagesServerClient []string |
||||
FirstKexFollows bool |
||||
Reserved uint32 |
||||
} |
||||
|
||||
// See RFC 4253, section 8.
|
||||
|
||||
// Diffie-Helman
|
||||
const msgKexDHInit = 30 |
||||
|
||||
type kexDHInitMsg struct { |
||||
X *big.Int `sshtype:"30"` |
||||
} |
||||
|
||||
const msgKexECDHInit = 30 |
||||
|
||||
type kexECDHInitMsg struct { |
||||
ClientPubKey []byte `sshtype:"30"` |
||||
} |
||||
|
||||
const msgKexECDHReply = 31 |
||||
|
||||
type kexECDHReplyMsg struct { |
||||
HostKey []byte `sshtype:"31"` |
||||
EphemeralPubKey []byte |
||||
Signature []byte |
||||
} |
||||
|
||||
const msgKexDHReply = 31 |
||||
|
||||
type kexDHReplyMsg struct { |
||||
HostKey []byte `sshtype:"31"` |
||||
Y *big.Int |
||||
Signature []byte |
||||
} |
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceRequest = 5 |
||||
|
||||
type serviceRequestMsg struct { |
||||
Service string `sshtype:"5"` |
||||
} |
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceAccept = 6 |
||||
|
||||
type serviceAcceptMsg struct { |
||||
Service string `sshtype:"6"` |
||||
} |
||||
|
||||
// See RFC 4252, section 5.
|
||||
const msgUserAuthRequest = 50 |
||||
|
||||
type userAuthRequestMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Payload []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4252, section 5.1
|
||||
const msgUserAuthFailure = 51 |
||||
|
||||
type userAuthFailureMsg struct { |
||||
Methods []string `sshtype:"51"` |
||||
PartialSuccess bool |
||||
} |
||||
|
||||
// See RFC 4256, section 3.2
|
||||
const msgUserAuthInfoRequest = 60 |
||||
const msgUserAuthInfoResponse = 61 |
||||
|
||||
type userAuthInfoRequestMsg struct { |
||||
User string `sshtype:"60"` |
||||
Instruction string |
||||
DeprecatedLanguage string |
||||
NumPrompts uint32 |
||||
Prompts []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpen = 90 |
||||
|
||||
type channelOpenMsg struct { |
||||
ChanType string `sshtype:"90"` |
||||
PeersId uint32 |
||||
PeersWindow uint32 |
||||
MaxPacketSize uint32 |
||||
TypeSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
const msgChannelExtendedData = 95 |
||||
const msgChannelData = 94 |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenConfirm = 91 |
||||
|
||||
type channelOpenConfirmMsg struct { |
||||
PeersId uint32 `sshtype:"91"` |
||||
MyId uint32 |
||||
MyWindow uint32 |
||||
MaxPacketSize uint32 |
||||
TypeSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenFailure = 92 |
||||
|
||||
type channelOpenFailureMsg struct { |
||||
PeersId uint32 `sshtype:"92"` |
||||
Reason RejectionReason |
||||
Message string |
||||
Language string |
||||
} |
||||
|
||||
const msgChannelRequest = 98 |
||||
|
||||
type channelRequestMsg struct { |
||||
PeersId uint32 `sshtype:"98"` |
||||
Request string |
||||
WantReply bool |
||||
RequestSpecificData []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelSuccess = 99 |
||||
|
||||
type channelRequestSuccessMsg struct { |
||||
PeersId uint32 `sshtype:"99"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelFailure = 100 |
||||
|
||||
type channelRequestFailureMsg struct { |
||||
PeersId uint32 `sshtype:"100"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelClose = 97 |
||||
|
||||
type channelCloseMsg struct { |
||||
PeersId uint32 `sshtype:"97"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelEOF = 96 |
||||
|
||||
type channelEOFMsg struct { |
||||
PeersId uint32 `sshtype:"96"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgGlobalRequest = 80 |
||||
|
||||
type globalRequestMsg struct { |
||||
Type string `sshtype:"80"` |
||||
WantReply bool |
||||
Data []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestSuccess = 81 |
||||
|
||||
type globalRequestSuccessMsg struct { |
||||
Data []byte `ssh:"rest" sshtype:"81"` |
||||
} |
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestFailure = 82 |
||||
|
||||
type globalRequestFailureMsg struct { |
||||
Data []byte `ssh:"rest" sshtype:"82"` |
||||
} |
||||
|
||||
// See RFC 4254, section 5.2
|
||||
const msgChannelWindowAdjust = 93 |
||||
|
||||
type windowAdjustMsg struct { |
||||
PeersId uint32 `sshtype:"93"` |
||||
AdditionalBytes uint32 |
||||
} |
||||
|
||||
// See RFC 4252, section 7
|
||||
const msgUserAuthPubKeyOk = 60 |
||||
|
||||
type userAuthPubKeyOkMsg struct { |
||||
Algo string `sshtype:"60"` |
||||
PubKey []byte |
||||
} |
||||
|
||||
// typeTag returns the type byte for the given type. The type should
|
||||
// be struct.
|
||||
func typeTag(structType reflect.Type) byte { |
||||
var tag byte |
||||
var tagStr string |
||||
tagStr = structType.Field(0).Tag.Get("sshtype") |
||||
i, err := strconv.Atoi(tagStr) |
||||
if err == nil { |
||||
tag = byte(i) |
||||
} |
||||
return tag |
||||
} |
||||
|
||||
func fieldError(t reflect.Type, field int, problem string) error { |
||||
if problem != "" { |
||||
problem = ": " + problem |
||||
} |
||||
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) |
||||
} |
||||
|
||||
var errShortRead = errors.New("ssh: short read") |
||||
|
||||
// Unmarshal parses data in SSH wire format into a structure. The out
|
||||
// argument should be a pointer to struct. If the first member of the
|
||||
// struct has the "sshtype" tag set to a number in decimal, the packet
|
||||
// must start that number. In case of error, Unmarshal returns a
|
||||
// ParseError or UnexpectedMessageError.
|
||||
func Unmarshal(data []byte, out interface{}) error { |
||||
v := reflect.ValueOf(out).Elem() |
||||
structType := v.Type() |
||||
expectedType := typeTag(structType) |
||||
if len(data) == 0 { |
||||
return parseError(expectedType) |
||||
} |
||||
if expectedType > 0 { |
||||
if data[0] != expectedType { |
||||
return unexpectedMessageError(expectedType, data[0]) |
||||
} |
||||
data = data[1:] |
||||
} |
||||
|
||||
var ok bool |
||||
for i := 0; i < v.NumField(); i++ { |
||||
field := v.Field(i) |
||||
t := field.Type() |
||||
switch t.Kind() { |
||||
case reflect.Bool: |
||||
if len(data) < 1 { |
||||
return errShortRead |
||||
} |
||||
field.SetBool(data[0] != 0) |
||||
data = data[1:] |
||||
case reflect.Array: |
||||
if t.Elem().Kind() != reflect.Uint8 { |
||||
return fieldError(structType, i, "array of unsupported type") |
||||
} |
||||
if len(data) < t.Len() { |
||||
return errShortRead |
||||
} |
||||
for j, n := 0, t.Len(); j < n; j++ { |
||||
field.Index(j).Set(reflect.ValueOf(data[j])) |
||||
} |
||||
data = data[t.Len():] |
||||
case reflect.Uint64: |
||||
var u64 uint64 |
||||
if u64, data, ok = parseUint64(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(u64) |
||||
case reflect.Uint32: |
||||
var u32 uint32 |
||||
if u32, data, ok = parseUint32(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(uint64(u32)) |
||||
case reflect.Uint8: |
||||
if len(data) < 1 { |
||||
return errShortRead |
||||
} |
||||
field.SetUint(uint64(data[0])) |
||||
data = data[1:] |
||||
case reflect.String: |
||||
var s []byte |
||||
if s, data, ok = parseString(data); !ok { |
||||
return fieldError(structType, i, "") |
||||
} |
||||
field.SetString(string(s)) |
||||
case reflect.Slice: |
||||
switch t.Elem().Kind() { |
||||
case reflect.Uint8: |
||||
if structType.Field(i).Tag.Get("ssh") == "rest" { |
||||
field.Set(reflect.ValueOf(data)) |
||||
data = nil |
||||
} else { |
||||
var s []byte |
||||
if s, data, ok = parseString(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(s)) |
||||
} |
||||
case reflect.String: |
||||
var nl []string |
||||
if nl, data, ok = parseNameList(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(nl)) |
||||
default: |
||||
return fieldError(structType, i, "slice of unsupported type") |
||||
} |
||||
case reflect.Ptr: |
||||
if t == bigIntType { |
||||
var n *big.Int |
||||
if n, data, ok = parseInt(data); !ok { |
||||
return errShortRead |
||||
} |
||||
field.Set(reflect.ValueOf(n)) |
||||
} else { |
||||
return fieldError(structType, i, "pointer to unsupported type") |
||||
} |
||||
default: |
||||
return fieldError(structType, i, "unsupported type") |
||||
} |
||||
} |
||||
|
||||
if len(data) != 0 { |
||||
return parseError(expectedType) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Marshal serializes the message in msg to SSH wire format. The msg
|
||||
// argument should be a struct or pointer to struct. If the first
|
||||
// member has the "sshtype" tag set to a number in decimal, that
|
||||
// number is prepended to the result. If the last of member has the
|
||||
// "ssh" tag set to "rest", its contents are appended to the output.
|
||||
func Marshal(msg interface{}) []byte { |
||||
out := make([]byte, 0, 64) |
||||
return marshalStruct(out, msg) |
||||
} |
||||
|
||||
func marshalStruct(out []byte, msg interface{}) []byte { |
||||
v := reflect.Indirect(reflect.ValueOf(msg)) |
||||
msgType := typeTag(v.Type()) |
||||
if msgType > 0 { |
||||
out = append(out, msgType) |
||||
} |
||||
|
||||
for i, n := 0, v.NumField(); i < n; i++ { |
||||
field := v.Field(i) |
||||
switch t := field.Type(); t.Kind() { |
||||
case reflect.Bool: |
||||
var v uint8 |
||||
if field.Bool() { |
||||
v = 1 |
||||
} |
||||
out = append(out, v) |
||||
case reflect.Array: |
||||
if t.Elem().Kind() != reflect.Uint8 { |
||||
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) |
||||
} |
||||
for j, l := 0, t.Len(); j < l; j++ { |
||||
out = append(out, uint8(field.Index(j).Uint())) |
||||
} |
||||
case reflect.Uint32: |
||||
out = appendU32(out, uint32(field.Uint())) |
||||
case reflect.Uint64: |
||||
out = appendU64(out, uint64(field.Uint())) |
||||
case reflect.Uint8: |
||||
out = append(out, uint8(field.Uint())) |
||||
case reflect.String: |
||||
s := field.String() |
||||
out = appendInt(out, len(s)) |
||||
out = append(out, s...) |
||||
case reflect.Slice: |
||||
switch t.Elem().Kind() { |
||||
case reflect.Uint8: |
||||
if v.Type().Field(i).Tag.Get("ssh") != "rest" { |
||||
out = appendInt(out, field.Len()) |
||||
} |
||||
out = append(out, field.Bytes()...) |
||||
case reflect.String: |
||||
offset := len(out) |
||||
out = appendU32(out, 0) |
||||
if n := field.Len(); n > 0 { |
||||
for j := 0; j < n; j++ { |
||||
f := field.Index(j) |
||||
if j != 0 { |
||||
out = append(out, ',') |
||||
} |
||||
out = append(out, f.String()...) |
||||
} |
||||
// overwrite length value
|
||||
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) |
||||
} |
||||
default: |
||||
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) |
||||
} |
||||
case reflect.Ptr: |
||||
if t == bigIntType { |
||||
var n *big.Int |
||||
nValue := reflect.ValueOf(&n) |
||||
nValue.Elem().Set(field) |
||||
needed := intLength(n) |
||||
oldLength := len(out) |
||||
|
||||
if cap(out)-len(out) < needed { |
||||
newOut := make([]byte, len(out), 2*(len(out)+needed)) |
||||
copy(newOut, out) |
||||
out = newOut |
||||
} |
||||
out = out[:oldLength+needed] |
||||
marshalInt(out[oldLength:], n) |
||||
} else { |
||||
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) |
||||
} |
||||
} |
||||
} |
||||
|
||||
return out |
||||
} |
||||
|
||||
var bigOne = big.NewInt(1) |
||||
|
||||
func parseString(in []byte) (out, rest []byte, ok bool) { |
||||
if len(in) < 4 { |
||||
return |
||||
} |
||||
length := binary.BigEndian.Uint32(in) |
||||
in = in[4:] |
||||
if uint32(len(in)) < length { |
||||
return |
||||
} |
||||
out = in[:length] |
||||
rest = in[length:] |
||||
ok = true |
||||
return |
||||
} |
||||
|
||||
var ( |
||||
comma = []byte{','} |
||||
emptyNameList = []string{} |
||||
) |
||||
|
||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) { |
||||
contents, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
if len(contents) == 0 { |
||||
out = emptyNameList |
||||
return |
||||
} |
||||
parts := bytes.Split(contents, comma) |
||||
out = make([]string, len(parts)) |
||||
for i, part := range parts { |
||||
out[i] = string(part) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { |
||||
contents, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
out = new(big.Int) |
||||
|
||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 { |
||||
// This is a negative number
|
||||
notBytes := make([]byte, len(contents)) |
||||
for i := range notBytes { |
||||
notBytes[i] = ^contents[i] |
||||
} |
||||
out.SetBytes(notBytes) |
||||
out.Add(out, bigOne) |
||||
out.Neg(out) |
||||
} else { |
||||
// Positive number
|
||||
out.SetBytes(contents) |
||||
} |
||||
ok = true |
||||
return |
||||
} |
||||
|
||||
func parseUint32(in []byte) (uint32, []byte, bool) { |
||||
if len(in) < 4 { |
||||
return 0, nil, false |
||||
} |
||||
return binary.BigEndian.Uint32(in), in[4:], true |
||||
} |
||||
|
||||
func parseUint64(in []byte) (uint64, []byte, bool) { |
||||
if len(in) < 8 { |
||||
return 0, nil, false |
||||
} |
||||
return binary.BigEndian.Uint64(in), in[8:], true |
||||
} |
||||
|
||||
func intLength(n *big.Int) int { |
||||
length := 4 /* length bytes */ |
||||
if n.Sign() < 0 { |
||||
nMinus1 := new(big.Int).Neg(n) |
||||
nMinus1.Sub(nMinus1, bigOne) |
||||
bitLen := nMinus1.BitLen() |
||||
if bitLen%8 == 0 { |
||||
// The number will need 0xff padding
|
||||
length++ |
||||
} |
||||
length += (bitLen + 7) / 8 |
||||
} else if n.Sign() == 0 { |
||||
// A zero is the zero length string
|
||||
} else { |
||||
bitLen := n.BitLen() |
||||
if bitLen%8 == 0 { |
||||
// The number will need 0x00 padding
|
||||
length++ |
||||
} |
||||
length += (bitLen + 7) / 8 |
||||
} |
||||
|
||||
return length |
||||
} |
||||
|
||||
func marshalUint32(to []byte, n uint32) []byte { |
||||
binary.BigEndian.PutUint32(to, n) |
||||
return to[4:] |
||||
} |
||||
|
||||
func marshalUint64(to []byte, n uint64) []byte { |
||||
binary.BigEndian.PutUint64(to, n) |
||||
return to[8:] |
||||
} |
||||
|
||||
func marshalInt(to []byte, n *big.Int) []byte { |
||||
lengthBytes := to |
||||
to = to[4:] |
||||
length := 0 |
||||
|
||||
if n.Sign() < 0 { |
||||
// A negative number has to be converted to two's-complement
|
||||
// form. So we'll subtract 1 and invert. If the
|
||||
// most-significant-bit isn't set then we'll need to pad the
|
||||
// beginning with 0xff in order to keep the number negative.
|
||||
nMinus1 := new(big.Int).Neg(n) |
||||
nMinus1.Sub(nMinus1, bigOne) |
||||
bytes := nMinus1.Bytes() |
||||
for i := range bytes { |
||||
bytes[i] ^= 0xff |
||||
} |
||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 { |
||||
to[0] = 0xff |
||||
to = to[1:] |
||||
length++ |
||||
} |
||||
nBytes := copy(to, bytes) |
||||
to = to[nBytes:] |
||||
length += nBytes |
||||
} else if n.Sign() == 0 { |
||||
// A zero is the zero length string
|
||||
} else { |
||||
bytes := n.Bytes() |
||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 { |
||||
// We'll have to pad this with a 0x00 in order to
|
||||
// stop it looking like a negative number.
|
||||
to[0] = 0 |
||||
to = to[1:] |
||||
length++ |
||||
} |
||||
nBytes := copy(to, bytes) |
||||
to = to[nBytes:] |
||||
length += nBytes |
||||
} |
||||
|
||||
lengthBytes[0] = byte(length >> 24) |
||||
lengthBytes[1] = byte(length >> 16) |
||||
lengthBytes[2] = byte(length >> 8) |
||||
lengthBytes[3] = byte(length) |
||||
return to |
||||
} |
||||
|
||||
func writeInt(w io.Writer, n *big.Int) { |
||||
length := intLength(n) |
||||
buf := make([]byte, length) |
||||
marshalInt(buf, n) |
||||
w.Write(buf) |
||||
} |
||||
|
||||
func writeString(w io.Writer, s []byte) { |
||||
var lengthBytes [4]byte |
||||
lengthBytes[0] = byte(len(s) >> 24) |
||||
lengthBytes[1] = byte(len(s) >> 16) |
||||
lengthBytes[2] = byte(len(s) >> 8) |
||||
lengthBytes[3] = byte(len(s)) |
||||
w.Write(lengthBytes[:]) |
||||
w.Write(s) |
||||
} |
||||
|
||||
func stringLength(n int) int { |
||||
return 4 + n |
||||
} |
||||
|
||||
func marshalString(to []byte, s []byte) []byte { |
||||
to[0] = byte(len(s) >> 24) |
||||
to[1] = byte(len(s) >> 16) |
||||
to[2] = byte(len(s) >> 8) |
||||
to[3] = byte(len(s)) |
||||
to = to[4:] |
||||
copy(to, s) |
||||
return to[len(s):] |
||||
} |
||||
|
||||
var bigIntType = reflect.TypeOf((*big.Int)(nil)) |
||||
|
||||
// Decode a packet into its corresponding message.
|
||||
func decode(packet []byte) (interface{}, error) { |
||||
var msg interface{} |
||||
switch packet[0] { |
||||
case msgDisconnect: |
||||
msg = new(disconnectMsg) |
||||
case msgServiceRequest: |
||||
msg = new(serviceRequestMsg) |
||||
case msgServiceAccept: |
||||
msg = new(serviceAcceptMsg) |
||||
case msgKexInit: |
||||
msg = new(kexInitMsg) |
||||
case msgKexDHInit: |
||||
msg = new(kexDHInitMsg) |
||||
case msgKexDHReply: |
||||
msg = new(kexDHReplyMsg) |
||||
case msgUserAuthRequest: |
||||
msg = new(userAuthRequestMsg) |
||||
case msgUserAuthFailure: |
||||
msg = new(userAuthFailureMsg) |
||||
case msgUserAuthPubKeyOk: |
||||
msg = new(userAuthPubKeyOkMsg) |
||||
case msgGlobalRequest: |
||||
msg = new(globalRequestMsg) |
||||
case msgRequestSuccess: |
||||
msg = new(globalRequestSuccessMsg) |
||||
case msgRequestFailure: |
||||
msg = new(globalRequestFailureMsg) |
||||
case msgChannelOpen: |
||||
msg = new(channelOpenMsg) |
||||
case msgChannelOpenConfirm: |
||||
msg = new(channelOpenConfirmMsg) |
||||
case msgChannelOpenFailure: |
||||
msg = new(channelOpenFailureMsg) |
||||
case msgChannelWindowAdjust: |
||||
msg = new(windowAdjustMsg) |
||||
case msgChannelEOF: |
||||
msg = new(channelEOFMsg) |
||||
case msgChannelClose: |
||||
msg = new(channelCloseMsg) |
||||
case msgChannelRequest: |
||||
msg = new(channelRequestMsg) |
||||
case msgChannelSuccess: |
||||
msg = new(channelRequestSuccessMsg) |
||||
case msgChannelFailure: |
||||
msg = new(channelRequestFailureMsg) |
||||
default: |
||||
return nil, unexpectedMessageError(0, packet[0]) |
||||
} |
||||
if err := Unmarshal(packet, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return msg, nil |
||||
} |
@ -1,254 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"math/big" |
||||
"math/rand" |
||||
"reflect" |
||||
"testing" |
||||
"testing/quick" |
||||
) |
||||
|
||||
var intLengthTests = []struct { |
||||
val, length int |
||||
}{ |
||||
{0, 4 + 0}, |
||||
{1, 4 + 1}, |
||||
{127, 4 + 1}, |
||||
{128, 4 + 2}, |
||||
{-1, 4 + 1}, |
||||
} |
||||
|
||||
func TestIntLength(t *testing.T) { |
||||
for _, test := range intLengthTests { |
||||
v := new(big.Int).SetInt64(int64(test.val)) |
||||
length := intLength(v) |
||||
if length != test.length { |
||||
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type msgAllTypes struct { |
||||
Bool bool `sshtype:"21"` |
||||
Array [16]byte |
||||
Uint64 uint64 |
||||
Uint32 uint32 |
||||
Uint8 uint8 |
||||
String string |
||||
Strings []string |
||||
Bytes []byte |
||||
Int *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { |
||||
m := &msgAllTypes{} |
||||
m.Bool = rand.Intn(2) == 1 |
||||
randomBytes(m.Array[:], rand) |
||||
m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) |
||||
m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) |
||||
m.Uint8 = uint8(rand.Intn(1 << 8)) |
||||
m.String = string(m.Array[:]) |
||||
m.Strings = randomNameList(rand) |
||||
m.Bytes = m.Array[:] |
||||
m.Int = randomInt(rand) |
||||
m.Rest = m.Array[:] |
||||
return reflect.ValueOf(m) |
||||
} |
||||
|
||||
func TestMarshalUnmarshal(t *testing.T) { |
||||
rand := rand.New(rand.NewSource(0)) |
||||
iface := &msgAllTypes{} |
||||
ty := reflect.ValueOf(iface).Type() |
||||
|
||||
n := 100 |
||||
if testing.Short() { |
||||
n = 5 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
v, ok := quick.Value(ty, rand) |
||||
if !ok { |
||||
t.Errorf("failed to create value") |
||||
break |
||||
} |
||||
|
||||
m1 := v.Elem().Interface() |
||||
m2 := iface |
||||
|
||||
marshaled := Marshal(m1) |
||||
if err := Unmarshal(marshaled, m2); err != nil { |
||||
t.Errorf("Unmarshal %#v: %s", m1, err) |
||||
break |
||||
} |
||||
|
||||
if !reflect.DeepEqual(v.Interface(), m2) { |
||||
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestUnmarshalEmptyPacket(t *testing.T) { |
||||
var b []byte |
||||
var m channelRequestSuccessMsg |
||||
if err := Unmarshal(b, &m); err == nil { |
||||
t.Fatalf("unmarshal of empty slice succeeded") |
||||
} |
||||
} |
||||
|
||||
func TestUnmarshalUnexpectedPacket(t *testing.T) { |
||||
type S struct { |
||||
I uint32 `sshtype:"43"` |
||||
S string |
||||
B bool |
||||
} |
||||
|
||||
s := S{11, "hello", true} |
||||
packet := Marshal(s) |
||||
packet[0] = 42 |
||||
roundtrip := S{} |
||||
err := Unmarshal(packet, &roundtrip) |
||||
if err == nil { |
||||
t.Fatal("expected error, not nil") |
||||
} |
||||
} |
||||
|
||||
func TestMarshalPtr(t *testing.T) { |
||||
s := struct { |
||||
S string |
||||
}{"hello"} |
||||
|
||||
m1 := Marshal(s) |
||||
m2 := Marshal(&s) |
||||
if !bytes.Equal(m1, m2) { |
||||
t.Errorf("got %q, want %q for marshaled pointer", m2, m1) |
||||
} |
||||
} |
||||
|
||||
func TestBareMarshalUnmarshal(t *testing.T) { |
||||
type S struct { |
||||
I uint32 |
||||
S string |
||||
B bool |
||||
} |
||||
|
||||
s := S{42, "hello", true} |
||||
packet := Marshal(s) |
||||
roundtrip := S{} |
||||
Unmarshal(packet, &roundtrip) |
||||
|
||||
if !reflect.DeepEqual(s, roundtrip) { |
||||
t.Errorf("got %#v, want %#v", roundtrip, s) |
||||
} |
||||
} |
||||
|
||||
func TestBareMarshal(t *testing.T) { |
||||
type S2 struct { |
||||
I uint32 |
||||
} |
||||
s := S2{42} |
||||
packet := Marshal(s) |
||||
i, rest, ok := parseUint32(packet) |
||||
if len(rest) > 0 || !ok { |
||||
t.Errorf("parseInt(%q): parse error", packet) |
||||
} |
||||
if i != s.I { |
||||
t.Errorf("got %d, want %d", i, s.I) |
||||
} |
||||
} |
||||
|
||||
func TestUnmarshalShortKexInitPacket(t *testing.T) { |
||||
// This used to panic.
|
||||
// Issue 11348
|
||||
packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} |
||||
kim := &kexInitMsg{} |
||||
if err := Unmarshal(packet, kim); err == nil { |
||||
t.Error("truncated packet unmarshaled without error") |
||||
} |
||||
} |
||||
|
||||
func randomBytes(out []byte, rand *rand.Rand) { |
||||
for i := 0; i < len(out); i++ { |
||||
out[i] = byte(rand.Int31()) |
||||
} |
||||
} |
||||
|
||||
func randomNameList(rand *rand.Rand) []string { |
||||
ret := make([]string, rand.Int31()&15) |
||||
for i := range ret { |
||||
s := make([]byte, 1+(rand.Int31()&15)) |
||||
for j := range s { |
||||
s[j] = 'a' + uint8(rand.Int31()&15) |
||||
} |
||||
ret[i] = string(s) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
func randomInt(rand *rand.Rand) *big.Int { |
||||
return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) |
||||
} |
||||
|
||||
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
||||
ki := &kexInitMsg{} |
||||
randomBytes(ki.Cookie[:], rand) |
||||
ki.KexAlgos = randomNameList(rand) |
||||
ki.ServerHostKeyAlgos = randomNameList(rand) |
||||
ki.CiphersClientServer = randomNameList(rand) |
||||
ki.CiphersServerClient = randomNameList(rand) |
||||
ki.MACsClientServer = randomNameList(rand) |
||||
ki.MACsServerClient = randomNameList(rand) |
||||
ki.CompressionClientServer = randomNameList(rand) |
||||
ki.CompressionServerClient = randomNameList(rand) |
||||
ki.LanguagesClientServer = randomNameList(rand) |
||||
ki.LanguagesServerClient = randomNameList(rand) |
||||
if rand.Int31()&1 == 1 { |
||||
ki.FirstKexFollows = true |
||||
} |
||||
return reflect.ValueOf(ki) |
||||
} |
||||
|
||||
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { |
||||
dhi := &kexDHInitMsg{} |
||||
dhi.X = randomInt(rand) |
||||
return reflect.ValueOf(dhi) |
||||
} |
||||
|
||||
var ( |
||||
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
||||
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() |
||||
|
||||
_kexInit = Marshal(_kexInitMsg) |
||||
_kexDHInit = Marshal(_kexDHInitMsg) |
||||
) |
||||
|
||||
func BenchmarkMarshalKexInitMsg(b *testing.B) { |
||||
for i := 0; i < b.N; i++ { |
||||
Marshal(_kexInitMsg) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkUnmarshalKexInitMsg(b *testing.B) { |
||||
m := new(kexInitMsg) |
||||
for i := 0; i < b.N; i++ { |
||||
Unmarshal(_kexInit, m) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkMarshalKexDHInitMsg(b *testing.B) { |
||||
for i := 0; i < b.N; i++ { |
||||
Marshal(_kexDHInitMsg) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { |
||||
m := new(kexDHInitMsg) |
||||
for i := 0; i < b.N; i++ { |
||||
Unmarshal(_kexDHInit, m) |
||||
} |
||||
} |
@ -1,356 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"sync" |
||||
"sync/atomic" |
||||
) |
||||
|
||||
// debugMux, if set, causes messages in the connection protocol to be
|
||||
// logged.
|
||||
const debugMux = false |
||||
|
||||
// chanList is a thread safe channel list.
|
||||
type chanList struct { |
||||
// protects concurrent access to chans
|
||||
sync.Mutex |
||||
|
||||
// chans are indexed by the local id of the channel, which the
|
||||
// other side should send in the PeersId field.
|
||||
chans []*channel |
||||
|
||||
// This is a debugging aid: it offsets all IDs by this
|
||||
// amount. This helps distinguish otherwise identical
|
||||
// server/client muxes
|
||||
offset uint32 |
||||
} |
||||
|
||||
// Assigns a channel ID to the given channel.
|
||||
func (c *chanList) add(ch *channel) uint32 { |
||||
c.Lock() |
||||
defer c.Unlock() |
||||
for i := range c.chans { |
||||
if c.chans[i] == nil { |
||||
c.chans[i] = ch |
||||
return uint32(i) + c.offset |
||||
} |
||||
} |
||||
c.chans = append(c.chans, ch) |
||||
return uint32(len(c.chans)-1) + c.offset |
||||
} |
||||
|
||||
// getChan returns the channel for the given ID.
|
||||
func (c *chanList) getChan(id uint32) *channel { |
||||
id -= c.offset |
||||
|
||||
c.Lock() |
||||
defer c.Unlock() |
||||
if id < uint32(len(c.chans)) { |
||||
return c.chans[id] |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *chanList) remove(id uint32) { |
||||
id -= c.offset |
||||
c.Lock() |
||||
if id < uint32(len(c.chans)) { |
||||
c.chans[id] = nil |
||||
} |
||||
c.Unlock() |
||||
} |
||||
|
||||
// dropAll forgets all channels it knows, returning them in a slice.
|
||||
func (c *chanList) dropAll() []*channel { |
||||
c.Lock() |
||||
defer c.Unlock() |
||||
var r []*channel |
||||
|
||||
for _, ch := range c.chans { |
||||
if ch == nil { |
||||
continue |
||||
} |
||||
r = append(r, ch) |
||||
} |
||||
c.chans = nil |
||||
return r |
||||
} |
||||
|
||||
// mux represents the state for the SSH connection protocol, which
|
||||
// multiplexes many channels onto a single packet transport.
|
||||
type mux struct { |
||||
conn packetConn |
||||
chanList chanList |
||||
|
||||
incomingChannels chan NewChannel |
||||
|
||||
globalSentMu sync.Mutex |
||||
globalResponses chan interface{} |
||||
incomingRequests chan *Request |
||||
|
||||
errCond *sync.Cond |
||||
err error |
||||
} |
||||
|
||||
// When debugging, each new chanList instantiation has a different
|
||||
// offset.
|
||||
var globalOff uint32 |
||||
|
||||
func (m *mux) Wait() error { |
||||
m.errCond.L.Lock() |
||||
defer m.errCond.L.Unlock() |
||||
for m.err == nil { |
||||
m.errCond.Wait() |
||||
} |
||||
return m.err |
||||
} |
||||
|
||||
// newMux returns a mux that runs over the given connection.
|
||||
func newMux(p packetConn) *mux { |
||||
m := &mux{ |
||||
conn: p, |
||||
incomingChannels: make(chan NewChannel, 16), |
||||
globalResponses: make(chan interface{}, 1), |
||||
incomingRequests: make(chan *Request, 16), |
||||
errCond: newCond(), |
||||
} |
||||
if debugMux { |
||||
m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
||||
} |
||||
|
||||
go m.loop() |
||||
return m |
||||
} |
||||
|
||||
func (m *mux) sendMessage(msg interface{}) error { |
||||
p := Marshal(msg) |
||||
return m.conn.writePacket(p) |
||||
} |
||||
|
||||
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { |
||||
if wantReply { |
||||
m.globalSentMu.Lock() |
||||
defer m.globalSentMu.Unlock() |
||||
} |
||||
|
||||
if err := m.sendMessage(globalRequestMsg{ |
||||
Type: name, |
||||
WantReply: wantReply, |
||||
Data: payload, |
||||
}); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if !wantReply { |
||||
return false, nil, nil |
||||
} |
||||
|
||||
msg, ok := <-m.globalResponses |
||||
if !ok { |
||||
return false, nil, io.EOF |
||||
} |
||||
switch msg := msg.(type) { |
||||
case *globalRequestFailureMsg: |
||||
return false, msg.Data, nil |
||||
case *globalRequestSuccessMsg: |
||||
return true, msg.Data, nil |
||||
default: |
||||
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) |
||||
} |
||||
} |
||||
|
||||
// ackRequest must be called after processing a global request that
|
||||
// has WantReply set.
|
||||
func (m *mux) ackRequest(ok bool, data []byte) error { |
||||
if ok { |
||||
return m.sendMessage(globalRequestSuccessMsg{Data: data}) |
||||
} |
||||
return m.sendMessage(globalRequestFailureMsg{Data: data}) |
||||
} |
||||
|
||||
// TODO(hanwen): Disconnect is a transport layer message. We should
|
||||
// probably send and receive Disconnect somewhere in the transport
|
||||
// code.
|
||||
|
||||
// Disconnect sends a disconnect message.
|
||||
func (m *mux) Disconnect(reason uint32, message string) error { |
||||
return m.sendMessage(disconnectMsg{ |
||||
Reason: reason, |
||||
Message: message, |
||||
}) |
||||
} |
||||
|
||||
func (m *mux) Close() error { |
||||
return m.conn.Close() |
||||
} |
||||
|
||||
// loop runs the connection machine. It will process packets until an
|
||||
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
||||
func (m *mux) loop() { |
||||
var err error |
||||
for err == nil { |
||||
err = m.onePacket() |
||||
} |
||||
|
||||
for _, ch := range m.chanList.dropAll() { |
||||
ch.close() |
||||
} |
||||
|
||||
close(m.incomingChannels) |
||||
close(m.incomingRequests) |
||||
close(m.globalResponses) |
||||
|
||||
m.conn.Close() |
||||
|
||||
m.errCond.L.Lock() |
||||
m.err = err |
||||
m.errCond.Broadcast() |
||||
m.errCond.L.Unlock() |
||||
|
||||
if debugMux { |
||||
log.Println("loop exit", err) |
||||
} |
||||
} |
||||
|
||||
// onePacket reads and processes one packet.
|
||||
func (m *mux) onePacket() error { |
||||
packet, err := m.conn.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if debugMux { |
||||
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { |
||||
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) |
||||
} else { |
||||
p, _ := decode(packet) |
||||
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) |
||||
} |
||||
} |
||||
|
||||
switch packet[0] { |
||||
case msgNewKeys: |
||||
// Ignore notification of key change.
|
||||
return nil |
||||
case msgDisconnect: |
||||
return m.handleDisconnect(packet) |
||||
case msgChannelOpen: |
||||
return m.handleChannelOpen(packet) |
||||
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: |
||||
return m.handleGlobalPacket(packet) |
||||
} |
||||
|
||||
// assume a channel packet.
|
||||
if len(packet) < 5 { |
||||
return parseError(packet[0]) |
||||
} |
||||
id := binary.BigEndian.Uint32(packet[1:]) |
||||
ch := m.chanList.getChan(id) |
||||
if ch == nil { |
||||
return fmt.Errorf("ssh: invalid channel %d", id) |
||||
} |
||||
|
||||
return ch.handlePacket(packet) |
||||
} |
||||
|
||||
func (m *mux) handleDisconnect(packet []byte) error { |
||||
var d disconnectMsg |
||||
if err := Unmarshal(packet, &d); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if debugMux { |
||||
log.Printf("caught disconnect: %v", d) |
||||
} |
||||
return &d |
||||
} |
||||
|
||||
func (m *mux) handleGlobalPacket(packet []byte) error { |
||||
msg, err := decode(packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *globalRequestMsg: |
||||
m.incomingRequests <- &Request{ |
||||
Type: msg.Type, |
||||
WantReply: msg.WantReply, |
||||
Payload: msg.Data, |
||||
mux: m, |
||||
} |
||||
case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
||||
m.globalResponses <- msg |
||||
default: |
||||
panic(fmt.Sprintf("not a global message %#v", msg)) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// handleChannelOpen schedules a channel to be Accept()ed.
|
||||
func (m *mux) handleChannelOpen(packet []byte) error { |
||||
var msg channelOpenMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||
failMsg := channelOpenFailureMsg{ |
||||
PeersId: msg.PeersId, |
||||
Reason: ConnectionFailed, |
||||
Message: "invalid request", |
||||
Language: "en_US.UTF-8", |
||||
} |
||||
return m.sendMessage(failMsg) |
||||
} |
||||
|
||||
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) |
||||
c.remoteId = msg.PeersId |
||||
c.maxRemotePayload = msg.MaxPacketSize |
||||
c.remoteWin.add(msg.PeersWindow) |
||||
m.incomingChannels <- c |
||||
return nil |
||||
} |
||||
|
||||
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { |
||||
ch, err := m.openChannel(chanType, extra) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return ch, ch.incomingRequests, nil |
||||
} |
||||
|
||||
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { |
||||
ch := m.newChannel(chanType, channelOutbound, extra) |
||||
|
||||
ch.maxIncomingPayload = channelMaxPacket |
||||
|
||||
open := channelOpenMsg{ |
||||
ChanType: chanType, |
||||
PeersWindow: ch.myWindow, |
||||
MaxPacketSize: ch.maxIncomingPayload, |
||||
TypeSpecificData: extra, |
||||
PeersId: ch.localId, |
||||
} |
||||
if err := m.sendMessage(open); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := (<-ch.msg).(type) { |
||||
case *channelOpenConfirmMsg: |
||||
return ch, nil |
||||
case *channelOpenFailureMsg: |
||||
return nil, &OpenChannelError{msg.Reason, msg.Message} |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) |
||||
} |
||||
} |
@ -1,525 +0,0 @@ |
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"io/ioutil" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
func muxPair() (*mux, *mux) { |
||||
a, b := memPipe() |
||||
|
||||
s := newMux(a) |
||||
c := newMux(b) |
||||
|
||||
return s, c |
||||
} |
||||
|
||||
// Returns both ends of a channel, and the mux for the the 2nd
|
||||
// channel.
|
||||
func channelPair(t *testing.T) (*channel, *channel, *mux) { |
||||
c, s := muxPair() |
||||
|
||||
res := make(chan *channel, 1) |
||||
go func() { |
||||
newCh, ok := <-s.incomingChannels |
||||
if !ok { |
||||
t.Fatalf("No incoming channel") |
||||
} |
||||
if newCh.ChannelType() != "chan" { |
||||
t.Fatalf("got type %q want chan", newCh.ChannelType()) |
||||
} |
||||
ch, _, err := newCh.Accept() |
||||
if err != nil { |
||||
t.Fatalf("Accept %v", err) |
||||
} |
||||
res <- ch.(*channel) |
||||
}() |
||||
|
||||
ch, err := c.openChannel("chan", nil) |
||||
if err != nil { |
||||
t.Fatalf("OpenChannel: %v", err) |
||||
} |
||||
|
||||
return <-res, ch, c |
||||
} |
||||
|
||||
// Test that stderr and stdout can be addressed from different
|
||||
// goroutines. This is intended for use with the race detector.
|
||||
func TestMuxChannelExtendedThreadSafety(t *testing.T) { |
||||
writer, reader, mux := channelPair(t) |
||||
defer writer.Close() |
||||
defer reader.Close() |
||||
defer mux.Close() |
||||
|
||||
var wr, rd sync.WaitGroup |
||||
magic := "hello world" |
||||
|
||||
wr.Add(2) |
||||
go func() { |
||||
io.WriteString(writer, magic) |
||||
wr.Done() |
||||
}() |
||||
go func() { |
||||
io.WriteString(writer.Stderr(), magic) |
||||
wr.Done() |
||||
}() |
||||
|
||||
rd.Add(2) |
||||
go func() { |
||||
c, err := ioutil.ReadAll(reader) |
||||
if string(c) != magic { |
||||
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) |
||||
} |
||||
rd.Done() |
||||
}() |
||||
go func() { |
||||
c, err := ioutil.ReadAll(reader.Stderr()) |
||||
if string(c) != magic { |
||||
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) |
||||
} |
||||
rd.Done() |
||||
}() |
||||
|
||||
wr.Wait() |
||||
writer.CloseWrite() |
||||
rd.Wait() |
||||
} |
||||
|
||||
func TestMuxReadWrite(t *testing.T) { |
||||
s, c, mux := channelPair(t) |
||||
defer s.Close() |
||||
defer c.Close() |
||||
defer mux.Close() |
||||
|
||||
magic := "hello world" |
||||
magicExt := "hello stderr" |
||||
go func() { |
||||
_, err := s.Write([]byte(magic)) |
||||
if err != nil { |
||||
t.Fatalf("Write: %v", err) |
||||
} |
||||
_, err = s.Extended(1).Write([]byte(magicExt)) |
||||
if err != nil { |
||||
t.Fatalf("Write: %v", err) |
||||
} |
||||
err = s.Close() |
||||
if err != nil { |
||||
t.Fatalf("Close: %v", err) |
||||
} |
||||
}() |
||||
|
||||
var buf [1024]byte |
||||
n, err := c.Read(buf[:]) |
||||
if err != nil { |
||||
t.Fatalf("server Read: %v", err) |
||||
} |
||||
got := string(buf[:n]) |
||||
if got != magic { |
||||
t.Fatalf("server: got %q want %q", got, magic) |
||||
} |
||||
|
||||
n, err = c.Extended(1).Read(buf[:]) |
||||
if err != nil { |
||||
t.Fatalf("server Read: %v", err) |
||||
} |
||||
|
||||
got = string(buf[:n]) |
||||
if got != magicExt { |
||||
t.Fatalf("server: got %q want %q", got, magic) |
||||
} |
||||
} |
||||
|
||||
func TestMuxChannelOverflow(t *testing.T) { |
||||
reader, writer, mux := channelPair(t) |
||||
defer reader.Close() |
||||
defer writer.Close() |
||||
defer mux.Close() |
||||
|
||||
wDone := make(chan int, 1) |
||||
go func() { |
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||
t.Errorf("could not fill window: %v", err) |
||||
} |
||||
writer.Write(make([]byte, 1)) |
||||
wDone <- 1 |
||||
}() |
||||
writer.remoteWin.waitWriterBlocked() |
||||
|
||||
// Send 1 byte.
|
||||
packet := make([]byte, 1+4+4+1) |
||||
packet[0] = msgChannelData |
||||
marshalUint32(packet[1:], writer.remoteId) |
||||
marshalUint32(packet[5:], uint32(1)) |
||||
packet[9] = 42 |
||||
|
||||
if err := writer.mux.conn.writePacket(packet); err != nil { |
||||
t.Errorf("could not send packet") |
||||
} |
||||
if _, err := reader.SendRequest("hello", true, nil); err == nil { |
||||
t.Errorf("SendRequest succeeded.") |
||||
} |
||||
<-wDone |
||||
} |
||||
|
||||
func TestMuxChannelCloseWriteUnblock(t *testing.T) { |
||||
reader, writer, mux := channelPair(t) |
||||
defer reader.Close() |
||||
defer writer.Close() |
||||
defer mux.Close() |
||||
|
||||
wDone := make(chan int, 1) |
||||
go func() { |
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||
t.Errorf("could not fill window: %v", err) |
||||
} |
||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { |
||||
t.Errorf("got %v, want EOF for unblock write", err) |
||||
} |
||||
wDone <- 1 |
||||
}() |
||||
|
||||
writer.remoteWin.waitWriterBlocked() |
||||
reader.Close() |
||||
<-wDone |
||||
} |
||||
|
||||
func TestMuxConnectionCloseWriteUnblock(t *testing.T) { |
||||
reader, writer, mux := channelPair(t) |
||||
defer reader.Close() |
||||
defer writer.Close() |
||||
defer mux.Close() |
||||
|
||||
wDone := make(chan int, 1) |
||||
go func() { |
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { |
||||
t.Errorf("could not fill window: %v", err) |
||||
} |
||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { |
||||
t.Errorf("got %v, want EOF for unblock write", err) |
||||
} |
||||
wDone <- 1 |
||||
}() |
||||
|
||||
writer.remoteWin.waitWriterBlocked() |
||||
mux.Close() |
||||
<-wDone |
||||
} |
||||
|
||||
func TestMuxReject(t *testing.T) { |
||||
client, server := muxPair() |
||||
defer server.Close() |
||||
defer client.Close() |
||||
|
||||
go func() { |
||||
ch, ok := <-server.incomingChannels |
||||
if !ok { |
||||
t.Fatalf("Accept") |
||||
} |
||||
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { |
||||
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) |
||||
} |
||||
ch.Reject(RejectionReason(42), "message") |
||||
}() |
||||
|
||||
ch, err := client.openChannel("ch", []byte("extra")) |
||||
if ch != nil { |
||||
t.Fatal("openChannel not rejected") |
||||
} |
||||
|
||||
ocf, ok := err.(*OpenChannelError) |
||||
if !ok { |
||||
t.Errorf("got %#v want *OpenChannelError", err) |
||||
} else if ocf.Reason != 42 || ocf.Message != "message" { |
||||
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") |
||||
} |
||||
|
||||
want := "ssh: rejected: unknown reason 42 (message)" |
||||
if err.Error() != want { |
||||
t.Errorf("got %q, want %q", err.Error(), want) |
||||
} |
||||
} |
||||
|
||||
func TestMuxChannelRequest(t *testing.T) { |
||||
client, server, mux := channelPair(t) |
||||
defer server.Close() |
||||
defer client.Close() |
||||
defer mux.Close() |
||||
|
||||
var received int |
||||
var wg sync.WaitGroup |
||||
wg.Add(1) |
||||
go func() { |
||||
for r := range server.incomingRequests { |
||||
received++ |
||||
r.Reply(r.Type == "yes", nil) |
||||
} |
||||
wg.Done() |
||||
}() |
||||
_, err := client.SendRequest("yes", false, nil) |
||||
if err != nil { |
||||
t.Fatalf("SendRequest: %v", err) |
||||
} |
||||
ok, err := client.SendRequest("yes", true, nil) |
||||
if err != nil { |
||||
t.Fatalf("SendRequest: %v", err) |
||||
} |
||||
|
||||
if !ok { |
||||
t.Errorf("SendRequest(yes): %v", ok) |
||||
|
||||
} |
||||
|
||||
ok, err = client.SendRequest("no", true, nil) |
||||
if err != nil { |
||||
t.Fatalf("SendRequest: %v", err) |
||||
} |
||||
if ok { |
||||
t.Errorf("SendRequest(no): %v", ok) |
||||
|
||||
} |
||||
|
||||
client.Close() |
||||
wg.Wait() |
||||
|
||||
if received != 3 { |
||||
t.Errorf("got %d requests, want %d", received, 3) |
||||
} |
||||
} |
||||
|
||||
func TestMuxGlobalRequest(t *testing.T) { |
||||
clientMux, serverMux := muxPair() |
||||
defer serverMux.Close() |
||||
defer clientMux.Close() |
||||
|
||||
var seen bool |
||||
go func() { |
||||
for r := range serverMux.incomingRequests { |
||||
seen = seen || r.Type == "peek" |
||||
if r.WantReply { |
||||
err := r.Reply(r.Type == "yes", |
||||
append([]byte(r.Type), r.Payload...)) |
||||
if err != nil { |
||||
t.Errorf("AckRequest: %v", err) |
||||
} |
||||
} |
||||
} |
||||
}() |
||||
|
||||
_, _, err := clientMux.SendRequest("peek", false, nil) |
||||
if err != nil { |
||||
t.Errorf("SendRequest: %v", err) |
||||
} |
||||
|
||||
ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) |
||||
if !ok || string(data) != "yesa" || err != nil { |
||||
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
||||
ok, data, err) |
||||
} |
||||
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { |
||||
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
||||
ok, data, err) |
||||
} |
||||
|
||||
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { |
||||
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", |
||||
ok, data, err) |
||||
} |
||||
|
||||
clientMux.Disconnect(0, "") |
||||
if !seen { |
||||
t.Errorf("never saw 'peek' request") |
||||
} |
||||
} |
||||
|
||||
func TestMuxGlobalRequestUnblock(t *testing.T) { |
||||
clientMux, serverMux := muxPair() |
||||
defer serverMux.Close() |
||||
defer clientMux.Close() |
||||
|
||||
result := make(chan error, 1) |
||||
go func() { |
||||
_, _, err := clientMux.SendRequest("hello", true, nil) |
||||
result <- err |
||||
}() |
||||
|
||||
<-serverMux.incomingRequests |
||||
serverMux.conn.Close() |
||||
err := <-result |
||||
|
||||
if err != io.EOF { |
||||
t.Errorf("want EOF, got %v", io.EOF) |
||||
} |
||||
} |
||||
|
||||
func TestMuxChannelRequestUnblock(t *testing.T) { |
||||
a, b, connB := channelPair(t) |
||||
defer a.Close() |
||||
defer b.Close() |
||||
defer connB.Close() |
||||
|
||||
result := make(chan error, 1) |
||||
go func() { |
||||
_, err := a.SendRequest("hello", true, nil) |
||||
result <- err |
||||
}() |
||||
|
||||
<-b.incomingRequests |
||||
connB.conn.Close() |
||||
err := <-result |
||||
|
||||
if err != io.EOF { |
||||
t.Errorf("want EOF, got %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestMuxDisconnect(t *testing.T) { |
||||
a, b := muxPair() |
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
go func() { |
||||
for r := range b.incomingRequests { |
||||
r.Reply(true, nil) |
||||
} |
||||
}() |
||||
|
||||
a.Disconnect(42, "whatever") |
||||
ok, _, err := a.SendRequest("hello", true, nil) |
||||
if ok || err == nil { |
||||
t.Errorf("got reply after disconnecting") |
||||
} |
||||
err = b.Wait() |
||||
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { |
||||
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) |
||||
} |
||||
} |
||||
|
||||
func TestMuxCloseChannel(t *testing.T) { |
||||
r, w, mux := channelPair(t) |
||||
defer mux.Close() |
||||
defer r.Close() |
||||
defer w.Close() |
||||
|
||||
result := make(chan error, 1) |
||||
go func() { |
||||
var b [1024]byte |
||||
_, err := r.Read(b[:]) |
||||
result <- err |
||||
}() |
||||
if err := w.Close(); err != nil { |
||||
t.Errorf("w.Close: %v", err) |
||||
} |
||||
|
||||
if _, err := w.Write([]byte("hello")); err != io.EOF { |
||||
t.Errorf("got err %v, want io.EOF after Close", err) |
||||
} |
||||
|
||||
if err := <-result; err != io.EOF { |
||||
t.Errorf("got %v (%T), want io.EOF", err, err) |
||||
} |
||||
} |
||||
|
||||
func TestMuxCloseWriteChannel(t *testing.T) { |
||||
r, w, mux := channelPair(t) |
||||
defer mux.Close() |
||||
|
||||
result := make(chan error, 1) |
||||
go func() { |
||||
var b [1024]byte |
||||
_, err := r.Read(b[:]) |
||||
result <- err |
||||
}() |
||||
if err := w.CloseWrite(); err != nil { |
||||
t.Errorf("w.CloseWrite: %v", err) |
||||
} |
||||
|
||||
if _, err := w.Write([]byte("hello")); err != io.EOF { |
||||
t.Errorf("got err %v, want io.EOF after CloseWrite", err) |
||||
} |
||||
|
||||
if err := <-result; err != io.EOF { |
||||
t.Errorf("got %v (%T), want io.EOF", err, err) |
||||
} |
||||
} |
||||
|
||||
func TestMuxInvalidRecord(t *testing.T) { |
||||
a, b := muxPair() |
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
packet := make([]byte, 1+4+4+1) |
||||
packet[0] = msgChannelData |
||||
marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
||||
marshalUint32(packet[5:], 1) |
||||
packet[9] = 42 |
||||
|
||||
a.conn.writePacket(packet) |
||||
go a.SendRequest("hello", false, nil) |
||||
// 'a' wrote an invalid packet, so 'b' has exited.
|
||||
req, ok := <-b.incomingRequests |
||||
if ok { |
||||
t.Errorf("got request %#v after receiving invalid packet", req) |
||||
} |
||||
} |
||||
|
||||
func TestZeroWindowAdjust(t *testing.T) { |
||||
a, b, mux := channelPair(t) |
||||
defer a.Close() |
||||
defer b.Close() |
||||
defer mux.Close() |
||||
|
||||
go func() { |
||||
io.WriteString(a, "hello") |
||||
// bogus adjust.
|
||||
a.sendMessage(windowAdjustMsg{}) |
||||
io.WriteString(a, "world") |
||||
a.Close() |
||||
}() |
||||
|
||||
want := "helloworld" |
||||
c, _ := ioutil.ReadAll(b) |
||||
if string(c) != want { |
||||
t.Errorf("got %q want %q", c, want) |
||||
} |
||||
} |
||||
|
||||
func TestMuxMaxPacketSize(t *testing.T) { |
||||
a, b, mux := channelPair(t) |
||||
defer a.Close() |
||||
defer b.Close() |
||||
defer mux.Close() |
||||
|
||||
large := make([]byte, a.maxRemotePayload+1) |
||||
packet := make([]byte, 1+4+4+1+len(large)) |
||||
packet[0] = msgChannelData |
||||
marshalUint32(packet[1:], a.remoteId) |
||||
marshalUint32(packet[5:], uint32(len(large))) |
||||
packet[9] = 42 |
||||
|
||||
if err := a.mux.conn.writePacket(packet); err != nil { |
||||
t.Errorf("could not send packet") |
||||
} |
||||
|
||||
go a.SendRequest("hello", false, nil) |
||||
|
||||
_, ok := <-b.incomingRequests |
||||
if ok { |
||||
t.Errorf("connection still alive after receiving large packet.") |
||||
} |
||||
} |
||||
|
||||
// Don't ship code with debug=true.
|
||||
func TestDebug(t *testing.T) { |
||||
if debugMux { |
||||
t.Error("mux debug switched on") |
||||
} |
||||
if debugHandshake { |
||||
t.Error("handshake debug switched on") |
||||
} |
||||
} |
@ -1,493 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
) |
||||
|
||||
// The Permissions type holds fine-grained permissions that are
|
||||
// specific to a user or a specific authentication method for a
|
||||
// user. Permissions, except for "source-address", must be enforced in
|
||||
// the server application layer, after successful authentication. The
|
||||
// Permissions are passed on in ServerConn so a server implementation
|
||||
// can honor them.
|
||||
type Permissions struct { |
||||
// Critical options restrict default permissions. Common
|
||||
// restrictions are "source-address" and "force-command". If
|
||||
// the server cannot enforce the restriction, or does not
|
||||
// recognize it, the user should not authenticate.
|
||||
CriticalOptions map[string]string |
||||
|
||||
// Extensions are extra functionality that the server may
|
||||
// offer on authenticated connections. Common extensions are
|
||||
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
|
||||
// support for an extension does not preclude authenticating a
|
||||
// user.
|
||||
Extensions map[string]string |
||||
} |
||||
|
||||
// ServerConfig holds server specific configuration data.
|
||||
type ServerConfig struct { |
||||
// Config contains configuration shared between client and server.
|
||||
Config |
||||
|
||||
hostKeys []Signer |
||||
|
||||
// NoClientAuth is true if clients are allowed to connect without
|
||||
// authenticating.
|
||||
NoClientAuth bool |
||||
|
||||
// PasswordCallback, if non-nil, is called when a user
|
||||
// attempts to authenticate using a password.
|
||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) |
||||
|
||||
// PublicKeyCallback, if non-nil, is called when a client attempts public
|
||||
// key authentication. It must return true if the given public key is
|
||||
// valid for the given user. For example, see CertChecker.Authenticate.
|
||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||
|
||||
// KeyboardInteractiveCallback, if non-nil, is called when
|
||||
// keyboard-interactive authentication is selected (RFC
|
||||
// 4256). The client object's Challenge function should be
|
||||
// used to query the user. The callback may offer multiple
|
||||
// Challenge rounds. To avoid information leaks, the client
|
||||
// should be presented a challenge even if the user is
|
||||
// unknown.
|
||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) |
||||
|
||||
// AuthLogCallback, if non-nil, is called to log all authentication
|
||||
// attempts.
|
||||
AuthLogCallback func(conn ConnMetadata, method string, err error) |
||||
|
||||
// ServerVersion is the version identification string to
|
||||
// announce in the public handshake.
|
||||
// If empty, a reasonable default is used.
|
||||
ServerVersion string |
||||
} |
||||
|
||||
// AddHostKey adds a private key as a host key. If an existing host
|
||||
// key exists with the same algorithm, it is overwritten. Each server
|
||||
// config must have at least one host key.
|
||||
func (s *ServerConfig) AddHostKey(key Signer) { |
||||
for i, k := range s.hostKeys { |
||||
if k.PublicKey().Type() == key.PublicKey().Type() { |
||||
s.hostKeys[i] = key |
||||
return |
||||
} |
||||
} |
||||
|
||||
s.hostKeys = append(s.hostKeys, key) |
||||
} |
||||
|
||||
// cachedPubKey contains the results of querying whether a public key is
|
||||
// acceptable for a user.
|
||||
type cachedPubKey struct { |
||||
user string |
||||
pubKeyData []byte |
||||
result error |
||||
perms *Permissions |
||||
} |
||||
|
||||
const maxCachedPubKeys = 16 |
||||
|
||||
// pubKeyCache caches tests for public keys. Since SSH clients
|
||||
// will query whether a public key is acceptable before attempting to
|
||||
// authenticate with it, we end up with duplicate queries for public
|
||||
// key validity. The cache only applies to a single ServerConn.
|
||||
type pubKeyCache struct { |
||||
keys []cachedPubKey |
||||
} |
||||
|
||||
// get returns the result for a given user/algo/key tuple.
|
||||
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { |
||||
for _, k := range c.keys { |
||||
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { |
||||
return k, true |
||||
} |
||||
} |
||||
return cachedPubKey{}, false |
||||
} |
||||
|
||||
// add adds the given tuple to the cache.
|
||||
func (c *pubKeyCache) add(candidate cachedPubKey) { |
||||
if len(c.keys) < maxCachedPubKeys { |
||||
c.keys = append(c.keys, candidate) |
||||
} |
||||
} |
||||
|
||||
// ServerConn is an authenticated SSH connection, as seen from the
|
||||
// server
|
||||
type ServerConn struct { |
||||
Conn |
||||
|
||||
// If the succeeding authentication callback returned a
|
||||
// non-nil Permissions pointer, it is stored here.
|
||||
Permissions *Permissions |
||||
} |
||||
|
||||
// NewServerConn starts a new SSH server with c as the underlying
|
||||
// transport. It starts with a handshake and, if the handshake is
|
||||
// unsuccessful, it closes the connection and returns an error. The
|
||||
// Request and NewChannel channels must be serviced, or the connection
|
||||
// will hang.
|
||||
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { |
||||
fullConf := *config |
||||
fullConf.SetDefaults() |
||||
s := &connection{ |
||||
sshConn: sshConn{conn: c}, |
||||
} |
||||
perms, err := s.serverHandshake(&fullConf) |
||||
if err != nil { |
||||
c.Close() |
||||
return nil, nil, nil, err |
||||
} |
||||
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil |
||||
} |
||||
|
||||
// signAndMarshal signs the data with the appropriate algorithm,
|
||||
// and serializes the result in SSH wire format.
|
||||
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { |
||||
sig, err := k.Sign(rand, data) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return Marshal(sig), nil |
||||
} |
||||
|
||||
// handshake performs key exchange and user authentication.
|
||||
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { |
||||
if len(config.hostKeys) == 0 { |
||||
return nil, errors.New("ssh: server has no host keys") |
||||
} |
||||
|
||||
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { |
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||
} |
||||
|
||||
if config.ServerVersion != "" { |
||||
s.serverVersion = []byte(config.ServerVersion) |
||||
} else { |
||||
s.serverVersion = []byte(packageVersion) |
||||
} |
||||
var err error |
||||
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) |
||||
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) |
||||
|
||||
if err := s.transport.requestKeyChange(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if packet, err := s.transport.readPacket(); err != nil { |
||||
return nil, err |
||||
} else if packet[0] != msgNewKeys { |
||||
return nil, unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
|
||||
// We just did the key change, so the session ID is established.
|
||||
s.sessionID = s.transport.getSessionID() |
||||
|
||||
var packet []byte |
||||
if packet, err = s.transport.readPacket(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var serviceRequest serviceRequestMsg |
||||
if err = Unmarshal(packet, &serviceRequest); err != nil { |
||||
return nil, err |
||||
} |
||||
if serviceRequest.Service != serviceUserAuth { |
||||
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") |
||||
} |
||||
serviceAccept := serviceAcceptMsg{ |
||||
Service: serviceUserAuth, |
||||
} |
||||
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
perms, err := s.serverAuthenticate(config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
s.mux = newMux(s.transport) |
||||
return perms, err |
||||
} |
||||
|
||||
func isAcceptableAlgo(algo string) bool { |
||||
switch algo { |
||||
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func checkSourceAddress(addr net.Addr, sourceAddr string) error { |
||||
if addr == nil { |
||||
return errors.New("ssh: no address known for client, but source-address match required") |
||||
} |
||||
|
||||
tcpAddr, ok := addr.(*net.TCPAddr) |
||||
if !ok { |
||||
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) |
||||
} |
||||
|
||||
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { |
||||
if bytes.Equal(allowedIP, tcpAddr.IP) { |
||||
return nil |
||||
} |
||||
} else { |
||||
_, ipNet, err := net.ParseCIDR(sourceAddr) |
||||
if err != nil { |
||||
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) |
||||
} |
||||
|
||||
if ipNet.Contains(tcpAddr.IP) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) |
||||
} |
||||
|
||||
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { |
||||
var err error |
||||
var cache pubKeyCache |
||||
var perms *Permissions |
||||
|
||||
userAuthLoop: |
||||
for { |
||||
var userAuthReq userAuthRequestMsg |
||||
if packet, err := s.transport.readPacket(); err != nil { |
||||
return nil, err |
||||
} else if err = Unmarshal(packet, &userAuthReq); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if userAuthReq.Service != serviceSSH { |
||||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) |
||||
} |
||||
|
||||
s.user = userAuthReq.User |
||||
perms = nil |
||||
authErr := errors.New("no auth passed yet") |
||||
|
||||
switch userAuthReq.Method { |
||||
case "none": |
||||
if config.NoClientAuth { |
||||
s.user = "" |
||||
authErr = nil |
||||
} |
||||
case "password": |
||||
if config.PasswordCallback == nil { |
||||
authErr = errors.New("ssh: password auth not configured") |
||||
break |
||||
} |
||||
payload := userAuthReq.Payload |
||||
if len(payload) < 1 || payload[0] != 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
payload = payload[1:] |
||||
password, payload, ok := parseString(payload) |
||||
if !ok || len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
perms, authErr = config.PasswordCallback(s, password) |
||||
case "keyboard-interactive": |
||||
if config.KeyboardInteractiveCallback == nil { |
||||
authErr = errors.New("ssh: keyboard-interactive auth not configubred") |
||||
break |
||||
} |
||||
|
||||
prompter := &sshClientKeyboardInteractive{s} |
||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) |
||||
case "publickey": |
||||
if config.PublicKeyCallback == nil { |
||||
authErr = errors.New("ssh: publickey auth not configured") |
||||
break |
||||
} |
||||
payload := userAuthReq.Payload |
||||
if len(payload) < 1 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
isQuery := payload[0] == 0 |
||||
payload = payload[1:] |
||||
algoBytes, payload, ok := parseString(payload) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
algo := string(algoBytes) |
||||
if !isAcceptableAlgo(algo) { |
||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) |
||||
break |
||||
} |
||||
|
||||
pubKeyData, payload, ok := parseString(payload) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
pubKey, err := ParsePublicKey(pubKeyData) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
candidate, ok := cache.get(s.user, pubKeyData) |
||||
if !ok { |
||||
candidate.user = s.user |
||||
candidate.pubKeyData = pubKeyData |
||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) |
||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { |
||||
candidate.result = checkSourceAddress( |
||||
s.RemoteAddr(), |
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption]) |
||||
} |
||||
cache.add(candidate) |
||||
} |
||||
|
||||
if isQuery { |
||||
// The client can query if the given public key
|
||||
// would be okay.
|
||||
if len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
|
||||
if candidate.result == nil { |
||||
okMsg := userAuthPubKeyOkMsg{ |
||||
Algo: algo, |
||||
PubKey: pubKeyData, |
||||
} |
||||
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { |
||||
return nil, err |
||||
} |
||||
continue userAuthLoop |
||||
} |
||||
authErr = candidate.result |
||||
} else { |
||||
sig, payload, ok := parseSignature(payload) |
||||
if !ok || len(payload) > 0 { |
||||
return nil, parseError(msgUserAuthRequest) |
||||
} |
||||
// Ensure the public key algo and signature algo
|
||||
// are supported. Compare the private key
|
||||
// algorithm name that corresponds to algo with
|
||||
// sig.Format. This is usually the same, but
|
||||
// for certs, the names differ.
|
||||
if !isAcceptableAlgo(sig.Format) { |
||||
break |
||||
} |
||||
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) |
||||
|
||||
if err := pubKey.Verify(signedData, sig); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
authErr = candidate.result |
||||
perms = candidate.perms |
||||
} |
||||
default: |
||||
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) |
||||
} |
||||
|
||||
if config.AuthLogCallback != nil { |
||||
config.AuthLogCallback(s, userAuthReq.Method, authErr) |
||||
} |
||||
|
||||
if authErr == nil { |
||||
break userAuthLoop |
||||
} |
||||
|
||||
var failureMsg userAuthFailureMsg |
||||
if config.PasswordCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "password") |
||||
} |
||||
if config.PublicKeyCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "publickey") |
||||
} |
||||
if config.KeyboardInteractiveCallback != nil { |
||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") |
||||
} |
||||
|
||||
if len(failureMsg.Methods) == 0 { |
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") |
||||
} |
||||
|
||||
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { |
||||
return nil, err |
||||
} |
||||
return perms, nil |
||||
} |
||||
|
||||
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
||||
// asking the client on the other side of a ServerConn.
|
||||
type sshClientKeyboardInteractive struct { |
||||
*connection |
||||
} |
||||
|
||||
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { |
||||
if len(questions) != len(echos) { |
||||
return nil, errors.New("ssh: echos and questions must have equal length") |
||||
} |
||||
|
||||
var prompts []byte |
||||
for i := range questions { |
||||
prompts = appendString(prompts, questions[i]) |
||||
prompts = appendBool(prompts, echos[i]) |
||||
} |
||||
|
||||
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ |
||||
Instruction: instruction, |
||||
NumPrompts: uint32(len(questions)), |
||||
Prompts: prompts, |
||||
})); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.transport.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if packet[0] != msgUserAuthInfoResponse { |
||||
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) |
||||
} |
||||
packet = packet[1:] |
||||
|
||||
n, packet, ok := parseUint32(packet) |
||||
if !ok || int(n) != len(questions) { |
||||
return nil, parseError(msgUserAuthInfoResponse) |
||||
} |
||||
|
||||
for i := uint32(0); i < n; i++ { |
||||
ans, rest, ok := parseString(packet) |
||||
if !ok { |
||||
return nil, parseError(msgUserAuthInfoResponse) |
||||
} |
||||
|
||||
answers = append(answers, string(ans)) |
||||
packet = rest |
||||
} |
||||
if len(packet) != 0 { |
||||
return nil, errors.New("ssh: junk at end of message") |
||||
} |
||||
|
||||
return answers, nil |
||||
} |
@ -1,605 +0,0 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Session implements an interactive session described in
|
||||
// "RFC 4254, section 6".
|
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"sync" |
||||
) |
||||
|
||||
type Signal string |
||||
|
||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
||||
const ( |
||||
SIGABRT Signal = "ABRT" |
||||
SIGALRM Signal = "ALRM" |
||||
SIGFPE Signal = "FPE" |
||||
SIGHUP Signal = "HUP" |
||||
SIGILL Signal = "ILL" |
||||
SIGINT Signal = "INT" |
||||
SIGKILL Signal = "KILL" |
||||
SIGPIPE Signal = "PIPE" |
||||
SIGQUIT Signal = "QUIT" |
||||
SIGSEGV Signal = "SEGV" |
||||
SIGTERM Signal = "TERM" |
||||
SIGUSR1 Signal = "USR1" |
||||
SIGUSR2 Signal = "USR2" |
||||
) |
||||
|
||||
var signals = map[Signal]int{ |
||||
SIGABRT: 6, |
||||
SIGALRM: 14, |
||||
SIGFPE: 8, |
||||
SIGHUP: 1, |
||||
SIGILL: 4, |
||||
SIGINT: 2, |
||||
SIGKILL: 9, |
||||
SIGPIPE: 13, |
||||
SIGQUIT: 3, |
||||
SIGSEGV: 11, |
||||
SIGTERM: 15, |
||||
} |
||||
|
||||
type TerminalModes map[uint8]uint32 |
||||
|
||||
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
||||
const ( |
||||
tty_OP_END = 0 |
||||
VINTR = 1 |
||||
VQUIT = 2 |
||||
VERASE = 3 |
||||
VKILL = 4 |
||||
VEOF = 5 |
||||
VEOL = 6 |
||||
VEOL2 = 7 |
||||
VSTART = 8 |
||||
VSTOP = 9 |
||||
VSUSP = 10 |
||||
VDSUSP = 11 |
||||
VREPRINT = 12 |
||||
VWERASE = 13 |
||||
VLNEXT = 14 |
||||
VFLUSH = 15 |
||||
VSWTCH = 16 |
||||
VSTATUS = 17 |
||||
VDISCARD = 18 |
||||
IGNPAR = 30 |
||||
PARMRK = 31 |
||||
INPCK = 32 |
||||
ISTRIP = 33 |
||||
INLCR = 34 |
||||
IGNCR = 35 |
||||
ICRNL = 36 |
||||
IUCLC = 37 |
||||
IXON = 38 |
||||
IXANY = 39 |
||||
IXOFF = 40 |
||||
IMAXBEL = 41 |
||||
ISIG = 50 |
||||
ICANON = 51 |
||||
XCASE = 52 |
||||
ECHO = 53 |
||||
ECHOE = 54 |
||||
ECHOK = 55 |
||||
ECHONL = 56 |
||||
NOFLSH = 57 |
||||
TOSTOP = 58 |
||||
IEXTEN = 59 |
||||
ECHOCTL = 60 |
||||
ECHOKE = 61 |
||||
PENDIN = 62 |
||||
OPOST = 70 |
||||
OLCUC = 71 |
||||
ONLCR = 72 |
||||
OCRNL = 73 |
||||
ONOCR = 74 |
||||
ONLRET = 75 |
||||
CS7 = 90 |
||||
CS8 = 91 |
||||
PARENB = 92 |
||||
PARODD = 93 |
||||
TTY_OP_ISPEED = 128 |
||||
TTY_OP_OSPEED = 129 |
||||
) |
||||
|
||||
// A Session represents a connection to a remote command or shell.
|
||||
type Session struct { |
||||
// Stdin specifies the remote process's standard input.
|
||||
// If Stdin is nil, the remote process reads from an empty
|
||||
// bytes.Buffer.
|
||||
Stdin io.Reader |
||||
|
||||
// Stdout and Stderr specify the remote process's standard
|
||||
// output and error.
|
||||
//
|
||||
// If either is nil, Run connects the corresponding file
|
||||
// descriptor to an instance of ioutil.Discard. There is a
|
||||
// fixed amount of buffering that is shared for the two streams.
|
||||
// If either blocks it may eventually cause the remote
|
||||
// command to block.
|
||||
Stdout io.Writer |
||||
Stderr io.Writer |
||||
|
||||
ch Channel // the channel backing this session
|
||||
started bool // true once Start, Run or Shell is invoked.
|
||||
copyFuncs []func() error |
||||
errors chan error // one send per copyFunc
|
||||
|
||||
// true if pipe method is active
|
||||
stdinpipe, stdoutpipe, stderrpipe bool |
||||
|
||||
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
||||
// and Stdin was specified by the user; it is the write end of
|
||||
// a pipe connecting Session.Stdin to the stdin channel.
|
||||
stdinPipeWriter io.WriteCloser |
||||
|
||||
exitStatus chan error |
||||
} |
||||
|
||||
// SendRequest sends an out-of-band channel request on the SSH channel
|
||||
// underlying the session.
|
||||
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||
return s.ch.SendRequest(name, wantReply, payload) |
||||
} |
||||
|
||||
func (s *Session) Close() error { |
||||
return s.ch.Close() |
||||
} |
||||
|
||||
// RFC 4254 Section 6.4.
|
||||
type setenvRequest struct { |
||||
Name string |
||||
Value string |
||||
} |
||||
|
||||
// Setenv sets an environment variable that will be applied to any
|
||||
// command executed by Shell or Run.
|
||||
func (s *Session) Setenv(name, value string) error { |
||||
msg := setenvRequest{ |
||||
Name: name, |
||||
Value: value, |
||||
} |
||||
ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: setenv failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.2.
|
||||
type ptyRequestMsg struct { |
||||
Term string |
||||
Columns uint32 |
||||
Rows uint32 |
||||
Width uint32 |
||||
Height uint32 |
||||
Modelist string |
||||
} |
||||
|
||||
// RequestPty requests the association of a pty with the session on the remote host.
|
||||
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { |
||||
var tm []byte |
||||
for k, v := range termmodes { |
||||
kv := struct { |
||||
Key byte |
||||
Val uint32 |
||||
}{k, v} |
||||
|
||||
tm = append(tm, Marshal(&kv)...) |
||||
} |
||||
tm = append(tm, tty_OP_END) |
||||
req := ptyRequestMsg{ |
||||
Term: term, |
||||
Columns: uint32(w), |
||||
Rows: uint32(h), |
||||
Width: uint32(w * 8), |
||||
Height: uint32(h * 8), |
||||
Modelist: string(tm), |
||||
} |
||||
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: pty-req failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type subsystemRequestMsg struct { |
||||
Subsystem string |
||||
} |
||||
|
||||
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
||||
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
||||
func (s *Session) RequestSubsystem(subsystem string) error { |
||||
msg := subsystemRequestMsg{ |
||||
Subsystem: subsystem, |
||||
} |
||||
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) |
||||
if err == nil && !ok { |
||||
err = errors.New("ssh: subsystem request failed") |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.9.
|
||||
type signalMsg struct { |
||||
Signal string |
||||
} |
||||
|
||||
// Signal sends the given signal to the remote process.
|
||||
// sig is one of the SIG* constants.
|
||||
func (s *Session) Signal(sig Signal) error { |
||||
msg := signalMsg{ |
||||
Signal: string(sig), |
||||
} |
||||
|
||||
_, err := s.ch.SendRequest("signal", false, Marshal(&msg)) |
||||
return err |
||||
} |
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type execMsg struct { |
||||
Command string |
||||
} |
||||
|
||||
// Start runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start or Shell.
|
||||
func (s *Session) Start(cmd string) error { |
||||
if s.started { |
||||
return errors.New("ssh: session already started") |
||||
} |
||||
req := execMsg{ |
||||
Command: cmd, |
||||
} |
||||
|
||||
ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) |
||||
if err == nil && !ok { |
||||
err = fmt.Errorf("ssh: command %v failed", cmd) |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.start() |
||||
} |
||||
|
||||
// Run runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start, Shell, Output,
|
||||
// or CombinedOutput.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the command fails to run or doesn't complete successfully, the
|
||||
// error is of type *ExitError. Other error types may be
|
||||
// returned for I/O problems.
|
||||
func (s *Session) Run(cmd string) error { |
||||
err := s.Start(cmd) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.Wait() |
||||
} |
||||
|
||||
// Output runs cmd on the remote host and returns its standard output.
|
||||
func (s *Session) Output(cmd string) ([]byte, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
var b bytes.Buffer |
||||
s.Stdout = &b |
||||
err := s.Run(cmd) |
||||
return b.Bytes(), err |
||||
} |
||||
|
||||
type singleWriter struct { |
||||
b bytes.Buffer |
||||
mu sync.Mutex |
||||
} |
||||
|
||||
func (w *singleWriter) Write(p []byte) (int, error) { |
||||
w.mu.Lock() |
||||
defer w.mu.Unlock() |
||||
return w.b.Write(p) |
||||
} |
||||
|
||||
// CombinedOutput runs cmd on the remote host and returns its combined
|
||||
// standard output and standard error.
|
||||
func (s *Session) CombinedOutput(cmd string) ([]byte, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
if s.Stderr != nil { |
||||
return nil, errors.New("ssh: Stderr already set") |
||||
} |
||||
var b singleWriter |
||||
s.Stdout = &b |
||||
s.Stderr = &b |
||||
err := s.Run(cmd) |
||||
return b.b.Bytes(), err |
||||
} |
||||
|
||||
// Shell starts a login shell on the remote host. A Session only
|
||||
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
||||
func (s *Session) Shell() error { |
||||
if s.started { |
||||
return errors.New("ssh: session already started") |
||||
} |
||||
|
||||
ok, err := s.ch.SendRequest("shell", true, nil) |
||||
if err == nil && !ok { |
||||
return fmt.Errorf("ssh: cound not start shell") |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return s.start() |
||||
} |
||||
|
||||
func (s *Session) start() error { |
||||
s.started = true |
||||
|
||||
type F func(*Session) |
||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { |
||||
setupFd(s) |
||||
} |
||||
|
||||
s.errors = make(chan error, len(s.copyFuncs)) |
||||
for _, fn := range s.copyFuncs { |
||||
go func(fn func() error) { |
||||
s.errors <- fn() |
||||
}(fn) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Wait waits for the remote command to exit.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the command fails to run or doesn't complete successfully, the
|
||||
// error is of type *ExitError. Other error types may be
|
||||
// returned for I/O problems.
|
||||
func (s *Session) Wait() error { |
||||
if !s.started { |
||||
return errors.New("ssh: session not started") |
||||
} |
||||
waitErr := <-s.exitStatus |
||||
|
||||
if s.stdinPipeWriter != nil { |
||||
s.stdinPipeWriter.Close() |
||||
} |
||||
var copyError error |
||||
for _ = range s.copyFuncs { |
||||
if err := <-s.errors; err != nil && copyError == nil { |
||||
copyError = err |
||||
} |
||||
} |
||||
if waitErr != nil { |
||||
return waitErr |
||||
} |
||||
return copyError |
||||
} |
||||
|
||||
func (s *Session) wait(reqs <-chan *Request) error { |
||||
wm := Waitmsg{status: -1} |
||||
// Wait for msg channel to be closed before returning.
|
||||
for msg := range reqs { |
||||
switch msg.Type { |
||||
case "exit-status": |
||||
d := msg.Payload |
||||
wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) |
||||
case "exit-signal": |
||||
var sigval struct { |
||||
Signal string |
||||
CoreDumped bool |
||||
Error string |
||||
Lang string |
||||
} |
||||
if err := Unmarshal(msg.Payload, &sigval); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Must sanitize strings?
|
||||
wm.signal = sigval.Signal |
||||
wm.msg = sigval.Error |
||||
wm.lang = sigval.Lang |
||||
default: |
||||
// This handles keepalives and matches
|
||||
// OpenSSH's behaviour.
|
||||
if msg.WantReply { |
||||
msg.Reply(false, nil) |
||||
} |
||||
} |
||||
} |
||||
if wm.status == 0 { |
||||
return nil |
||||
} |
||||
if wm.status == -1 { |
||||
// exit-status was never sent from server
|
||||
if wm.signal == "" { |
||||
return errors.New("wait: remote command exited without exit status or exit signal") |
||||
} |
||||
wm.status = 128 |
||||
if _, ok := signals[Signal(wm.signal)]; ok { |
||||
wm.status += signals[Signal(wm.signal)] |
||||
} |
||||
} |
||||
return &ExitError{wm} |
||||
} |
||||
|
||||
func (s *Session) stdin() { |
||||
if s.stdinpipe { |
||||
return |
||||
} |
||||
var stdin io.Reader |
||||
if s.Stdin == nil { |
||||
stdin = new(bytes.Buffer) |
||||
} else { |
||||
r, w := io.Pipe() |
||||
go func() { |
||||
_, err := io.Copy(w, s.Stdin) |
||||
w.CloseWithError(err) |
||||
}() |
||||
stdin, s.stdinPipeWriter = r, w |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.ch, stdin) |
||||
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { |
||||
err = err1 |
||||
} |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
func (s *Session) stdout() { |
||||
if s.stdoutpipe { |
||||
return |
||||
} |
||||
if s.Stdout == nil { |
||||
s.Stdout = ioutil.Discard |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.Stdout, s.ch) |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
func (s *Session) stderr() { |
||||
if s.stderrpipe { |
||||
return |
||||
} |
||||
if s.Stderr == nil { |
||||
s.Stderr = ioutil.Discard |
||||
} |
||||
s.copyFuncs = append(s.copyFuncs, func() error { |
||||
_, err := io.Copy(s.Stderr, s.ch.Stderr()) |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
// sessionStdin reroutes Close to CloseWrite.
|
||||
type sessionStdin struct { |
||||
io.Writer |
||||
ch Channel |
||||
} |
||||
|
||||
func (s *sessionStdin) Close() error { |
||||
return s.ch.CloseWrite() |
||||
} |
||||
|
||||
// StdinPipe returns a pipe that will be connected to the
|
||||
// remote command's standard input when the command starts.
|
||||
func (s *Session) StdinPipe() (io.WriteCloser, error) { |
||||
if s.Stdin != nil { |
||||
return nil, errors.New("ssh: Stdin already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StdinPipe after process started") |
||||
} |
||||
s.stdinpipe = true |
||||
return &sessionStdin{s.ch, s.ch}, nil |
||||
} |
||||
|
||||
// StdoutPipe returns a pipe that will be connected to the
|
||||
// remote command's standard output when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StdoutPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StdoutPipe() (io.Reader, error) { |
||||
if s.Stdout != nil { |
||||
return nil, errors.New("ssh: Stdout already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StdoutPipe after process started") |
||||
} |
||||
s.stdoutpipe = true |
||||
return s.ch, nil |
||||
} |
||||
|
||||
// StderrPipe returns a pipe that will be connected to the
|
||||
// remote command's standard error when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StderrPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StderrPipe() (io.Reader, error) { |
||||
if s.Stderr != nil { |
||||
return nil, errors.New("ssh: Stderr already set") |
||||
} |
||||
if s.started { |
||||
return nil, errors.New("ssh: StderrPipe after process started") |
||||
} |
||||
s.stderrpipe = true |
||||
return s.ch.Stderr(), nil |
||||
} |
||||
|
||||
// newSession returns a new interactive session on the remote host.
|
||||
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { |
||||
s := &Session{ |
||||
ch: ch, |
||||
} |
||||
s.exitStatus = make(chan error, 1) |
||||
go func() { |
||||
s.exitStatus <- s.wait(reqs) |
||||
}() |
||||
|
||||
return s, nil |
||||
} |
||||
|
||||
// An ExitError reports unsuccessful completion of a remote command.
|
||||
type ExitError struct { |
||||
Waitmsg |
||||
} |
||||
|
||||
func (e *ExitError) Error() string { |
||||
return e.Waitmsg.String() |
||||
} |
||||
|
||||
// Waitmsg stores the information about an exited remote command
|
||||
// as reported by Wait.
|
||||
type Waitmsg struct { |
||||
status int |
||||
signal string |
||||
msg string |
||||
lang string |
||||
} |
||||
|
||||
// ExitStatus returns the exit status of the remote command.
|
||||
func (w Waitmsg) ExitStatus() int { |
||||
return w.status |
||||
} |
||||
|
||||
// Signal returns the exit signal of the remote command if
|
||||
// it was terminated violently.
|
||||
func (w Waitmsg) Signal() string { |
||||
return w.signal |
||||
} |
||||
|
||||
// Msg returns the exit message given by the remote command
|
||||
func (w Waitmsg) Msg() string { |
||||
return w.msg |
||||
} |
||||
|
||||
// Lang returns the language tag. See RFC 3066
|
||||
func (w Waitmsg) Lang() string { |
||||
return w.lang |
||||
} |
||||
|
||||
func (w Waitmsg) String() string { |
||||
return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal) |
||||
} |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue