You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							330 lines
						
					
					
						
							7.2 KiB
						
					
					
				
			
		
		
	
	
							330 lines
						
					
					
						
							7.2 KiB
						
					
					
				| // Copyright 2013 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package ssh
 | |
| 
 | |
| import (
 | |
| 	"encoding/binary"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"log"
 | |
| 	"sync"
 | |
| 	"sync/atomic"
 | |
| )
 | |
| 
 | |
// debugMux turns on logging of the messages the mux sends and receives on
// the connection protocol. It is a constant rather than a variable, so it
// can only be switched on by editing the source.
const (
	debugMux = false
)
 | |
| 
 | |
| // chanList is a thread safe channel list.
 | |
| type chanList struct {
 | |
| 	// protects concurrent access to chans
 | |
| 	sync.Mutex
 | |
| 
 | |
| 	// chans are indexed by the local id of the channel, which the
 | |
| 	// other side should send in the PeersId field.
 | |
| 	chans []*channel
 | |
| 
 | |
| 	// This is a debugging aid: it offsets all IDs by this
 | |
| 	// amount. This helps distinguish otherwise identical
 | |
| 	// server/client muxes
 | |
| 	offset uint32
 | |
| }
 | |
| 
 | |
| // Assigns a channel ID to the given channel.
 | |
| func (c *chanList) add(ch *channel) uint32 {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 	for i := range c.chans {
 | |
| 		if c.chans[i] == nil {
 | |
| 			c.chans[i] = ch
 | |
| 			return uint32(i) + c.offset
 | |
| 		}
 | |
| 	}
 | |
| 	c.chans = append(c.chans, ch)
 | |
| 	return uint32(len(c.chans)-1) + c.offset
 | |
| }
 | |
| 
 | |
| // getChan returns the channel for the given ID.
 | |
| func (c *chanList) getChan(id uint32) *channel {
 | |
| 	id -= c.offset
 | |
| 
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 	if id < uint32(len(c.chans)) {
 | |
| 		return c.chans[id]
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *chanList) remove(id uint32) {
 | |
| 	id -= c.offset
 | |
| 	c.Lock()
 | |
| 	if id < uint32(len(c.chans)) {
 | |
| 		c.chans[id] = nil
 | |
| 	}
 | |
| 	c.Unlock()
 | |
| }
 | |
| 
 | |
| // dropAll forgets all channels it knows, returning them in a slice.
 | |
| func (c *chanList) dropAll() []*channel {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 	var r []*channel
 | |
| 
 | |
| 	for _, ch := range c.chans {
 | |
| 		if ch == nil {
 | |
| 			continue
 | |
| 		}
 | |
| 		r = append(r, ch)
 | |
| 	}
 | |
| 	c.chans = nil
 | |
| 	return r
 | |
| }
 | |
| 
 | |
| // mux represents the state for the SSH connection protocol, which
 | |
| // multiplexes many channels onto a single packet transport.
 | |
| type mux struct {
 | |
| 	conn     packetConn
 | |
| 	chanList chanList
 | |
| 
 | |
| 	incomingChannels chan NewChannel
 | |
| 
 | |
| 	globalSentMu     sync.Mutex
 | |
| 	globalResponses  chan interface{}
 | |
| 	incomingRequests chan *Request
 | |
| 
 | |
| 	errCond *sync.Cond
 | |
| 	err     error
 | |
| }
 | |
| 
 | |
// globalOff is incremented atomically by newMux when debugMux is set, so
// that every mux created in a debug build gets a distinct chanList offset.
var (
	globalOff uint32
)
 | |
| 
 | |
| func (m *mux) Wait() error {
 | |
| 	m.errCond.L.Lock()
 | |
| 	defer m.errCond.L.Unlock()
 | |
| 	for m.err == nil {
 | |
| 		m.errCond.Wait()
 | |
| 	}
 | |
| 	return m.err
 | |
| }
 | |
| 
 | |
| // newMux returns a mux that runs over the given connection.
 | |
| func newMux(p packetConn) *mux {
 | |
| 	m := &mux{
 | |
| 		conn:             p,
 | |
| 		incomingChannels: make(chan NewChannel, 16),
 | |
| 		globalResponses:  make(chan interface{}, 1),
 | |
| 		incomingRequests: make(chan *Request, 16),
 | |
| 		errCond:          newCond(),
 | |
| 	}
 | |
| 	if debugMux {
 | |
| 		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
 | |
| 	}
 | |
| 
 | |
| 	go m.loop()
 | |
| 	return m
 | |
| }
 | |
| 
 | |
| func (m *mux) sendMessage(msg interface{}) error {
 | |
| 	p := Marshal(msg)
 | |
| 	if debugMux {
 | |
| 		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
 | |
| 	}
 | |
| 	return m.conn.writePacket(p)
 | |
| }
 | |
| 
 | |
| func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
 | |
| 	if wantReply {
 | |
| 		m.globalSentMu.Lock()
 | |
| 		defer m.globalSentMu.Unlock()
 | |
| 	}
 | |
| 
 | |
| 	if err := m.sendMessage(globalRequestMsg{
 | |
| 		Type:      name,
 | |
| 		WantReply: wantReply,
 | |
| 		Data:      payload,
 | |
| 	}); err != nil {
 | |
| 		return false, nil, err
 | |
| 	}
 | |
| 
 | |
| 	if !wantReply {
 | |
| 		return false, nil, nil
 | |
| 	}
 | |
| 
 | |
| 	msg, ok := <-m.globalResponses
 | |
| 	if !ok {
 | |
| 		return false, nil, io.EOF
 | |
| 	}
 | |
| 	switch msg := msg.(type) {
 | |
| 	case *globalRequestFailureMsg:
 | |
| 		return false, msg.Data, nil
 | |
| 	case *globalRequestSuccessMsg:
 | |
| 		return true, msg.Data, nil
 | |
| 	default:
 | |
| 		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // ackRequest must be called after processing a global request that
 | |
| // has WantReply set.
 | |
| func (m *mux) ackRequest(ok bool, data []byte) error {
 | |
| 	if ok {
 | |
| 		return m.sendMessage(globalRequestSuccessMsg{Data: data})
 | |
| 	}
 | |
| 	return m.sendMessage(globalRequestFailureMsg{Data: data})
 | |
| }
 | |
| 
 | |
| func (m *mux) Close() error {
 | |
| 	return m.conn.Close()
 | |
| }
 | |
| 
 | |
| // loop runs the connection machine. It will process packets until an
 | |
| // error is encountered. To synchronize on loop exit, use mux.Wait.
 | |
| func (m *mux) loop() {
 | |
| 	var err error
 | |
| 	for err == nil {
 | |
| 		err = m.onePacket()
 | |
| 	}
 | |
| 
 | |
| 	for _, ch := range m.chanList.dropAll() {
 | |
| 		ch.close()
 | |
| 	}
 | |
| 
 | |
| 	close(m.incomingChannels)
 | |
| 	close(m.incomingRequests)
 | |
| 	close(m.globalResponses)
 | |
| 
 | |
| 	m.conn.Close()
 | |
| 
 | |
| 	m.errCond.L.Lock()
 | |
| 	m.err = err
 | |
| 	m.errCond.Broadcast()
 | |
| 	m.errCond.L.Unlock()
 | |
| 
 | |
| 	if debugMux {
 | |
| 		log.Println("loop exit", err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // onePacket reads and processes one packet.
 | |
| func (m *mux) onePacket() error {
 | |
| 	packet, err := m.conn.readPacket()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if debugMux {
 | |
| 		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
 | |
| 			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
 | |
| 		} else {
 | |
| 			p, _ := decode(packet)
 | |
| 			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch packet[0] {
 | |
| 	case msgChannelOpen:
 | |
| 		return m.handleChannelOpen(packet)
 | |
| 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
 | |
| 		return m.handleGlobalPacket(packet)
 | |
| 	}
 | |
| 
 | |
| 	// assume a channel packet.
 | |
| 	if len(packet) < 5 {
 | |
| 		return parseError(packet[0])
 | |
| 	}
 | |
| 	id := binary.BigEndian.Uint32(packet[1:])
 | |
| 	ch := m.chanList.getChan(id)
 | |
| 	if ch == nil {
 | |
| 		return fmt.Errorf("ssh: invalid channel %d", id)
 | |
| 	}
 | |
| 
 | |
| 	return ch.handlePacket(packet)
 | |
| }
 | |
| 
 | |
| func (m *mux) handleGlobalPacket(packet []byte) error {
 | |
| 	msg, err := decode(packet)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	switch msg := msg.(type) {
 | |
| 	case *globalRequestMsg:
 | |
| 		m.incomingRequests <- &Request{
 | |
| 			Type:      msg.Type,
 | |
| 			WantReply: msg.WantReply,
 | |
| 			Payload:   msg.Data,
 | |
| 			mux:       m,
 | |
| 		}
 | |
| 	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 | |
| 		m.globalResponses <- msg
 | |
| 	default:
 | |
| 		panic(fmt.Sprintf("not a global message %#v", msg))
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // handleChannelOpen schedules a channel to be Accept()ed.
 | |
| func (m *mux) handleChannelOpen(packet []byte) error {
 | |
| 	var msg channelOpenMsg
 | |
| 	if err := Unmarshal(packet, &msg); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
 | |
| 		failMsg := channelOpenFailureMsg{
 | |
| 			PeersId:  msg.PeersId,
 | |
| 			Reason:   ConnectionFailed,
 | |
| 			Message:  "invalid request",
 | |
| 			Language: "en_US.UTF-8",
 | |
| 		}
 | |
| 		return m.sendMessage(failMsg)
 | |
| 	}
 | |
| 
 | |
| 	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
 | |
| 	c.remoteId = msg.PeersId
 | |
| 	c.maxRemotePayload = msg.MaxPacketSize
 | |
| 	c.remoteWin.add(msg.PeersWindow)
 | |
| 	m.incomingChannels <- c
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
 | |
| 	ch, err := m.openChannel(chanType, extra)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	return ch, ch.incomingRequests, nil
 | |
| }
 | |
| 
 | |
| func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
 | |
| 	ch := m.newChannel(chanType, channelOutbound, extra)
 | |
| 
 | |
| 	ch.maxIncomingPayload = channelMaxPacket
 | |
| 
 | |
| 	open := channelOpenMsg{
 | |
| 		ChanType:         chanType,
 | |
| 		PeersWindow:      ch.myWindow,
 | |
| 		MaxPacketSize:    ch.maxIncomingPayload,
 | |
| 		TypeSpecificData: extra,
 | |
| 		PeersId:          ch.localId,
 | |
| 	}
 | |
| 	if err := m.sendMessage(open); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	switch msg := (<-ch.msg).(type) {
 | |
| 	case *channelOpenConfirmMsg:
 | |
| 		return ch, nil
 | |
| 	case *channelOpenFailureMsg:
 | |
| 		return nil, &OpenChannelError{msg.Reason, msg.Message}
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
 | |
| 	}
 | |
| }
 | |
| 
 |