| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // Copyright 2019 Yunion
- // Copyright 2016 Google Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // Package tftp implements a read-only TFTP server.
- package tftp
- import (
- "bytes"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "strconv"
- "time"
- )
- const (
- // DefaultWriteTimeout is the duration a client has to acknowledge
- // a data packet from the server. This can be overridden by
- // setting Server.WriteTimeout.
- DefaultWriteTimeout = 2 * time.Second
- // DefaultWriteAttempts is the maximum number of times a single
- // packet will be (re)sent before timing out a client. This can be
- // overridden by setting Server.WriteAttempts.
- DefaultWriteAttempts = 5
- // DefaultBlockSize is the maximum block size used to send data to
- // clients. The server will respect a request for a smaller block
- // size, but requests for larger block sizes will be clamped to
- // DefaultBlockSize. This can be overridden by setting
- // Server.MaxBlockSize.
- DefaultBlockSize = 1450
- // maxErrorSize is the largest error message string that will be
- // sent to the client without truncation.
- maxErrorSize = 500
- )
- // A Handler provides bytes for a file.
- //
- // If size is non-zero, it must be equal to the number of bytes in
- // file. The server will offer the "tsize" extension to clients that
- // request it.
- //
- // Note that some clients (particularly firmware TFTP clients) can be
- // very capricious about servers not supporting all the options that
- // they request, so passing a size of 0 may cause TFTP transfers to
- // fail for some clients.
- type Handler func(path string, clientAddr net.Addr) (file io.ReadCloser, size int64, err error)
- // A Server defines parameters for running a TFTP server.
- type Server struct {
- Handler Handler // handler to invoke for requests
- // WriteTimeout sets the duration to wait for the client to
- // acknowledge a data packet. Defaults to DefaultWriteTimeout.
- WriteTimeout time.Duration
- // WriteAttempts sets how many times a packet will be (re)sent
- // before timing out the client and aborting the transfer. If 0,
- // uses DefaultWriteAttempts.
- WriteAttempts int
- // MaxBlockSize sets the maximum block size used for file
- // transfers. If 0, uses DefaultBlockSize.
- MaxBlockSize int64
- // InfoLog specifies an optional logger for informational
- // messages. If nil, informational messages are suppressed.
- InfoLog func(msg string)
- // TransferLog specifies an optional logger for completed
- // transfers. A successful transfer is logged with err == nil. If
- // nil, transfer logs are suppressed.
- TransferLog func(clientAddr net.Addr, path string, err error)
- // Dial specifies a function to use when setting up a "connected"
- // UDP socket to a TFTP client. While this is mostly here for
- // testing, it can also be used to implement advanced relay
- // functionality (e.g. serving TFTP through SOCKS). If nil,
- // net.Dial is used.
- Dial func(network, addr string) (net.Conn, error)
- }
- // ListenAndServe listens on the UDP network address addr and then
- // calls Serve to handle TFTP requests. If addr is blank, ":69" is
- // used.
- func (s *Server) ListenAndServe(addr string) error {
- if addr == "" {
- addr = ":69"
- }
- l, err := net.ListenPacket("udp", addr)
- if err != nil {
- return err
- }
- defer l.Close()
- s.infoLog("TFTP listening on %s", l.LocalAddr())
- return s.Serve(l)
- }
- // Serve accepts requests on listener l, creating a new transfer
- // goroutine for each. The transfer goroutines use s.Handler to get
- // bytes, and transfers them to the client.
- func (s *Server) Serve(l net.PacketConn) error {
- if s.Handler == nil {
- return errors.New("can't serve, Handler is nil")
- }
- if err := l.SetDeadline(time.Time{}); err != nil {
- return err
- }
- buf := make([]byte, 2048)
- for {
- n, addr, err := l.ReadFrom(buf)
- if err != nil {
- return err
- }
- req, err := parseRRQ(buf[:n])
- if err != nil {
- s.infoLog("bad request from %q: %s", addr, err)
- continue
- }
- go s.transferAndLog(addr, req)
- }
- }
- func (s *Server) infoLog(msg string, args ...interface{}) {
- if s.InfoLog != nil {
- s.InfoLog(fmt.Sprintf(msg, args...))
- }
- }
- func (s *Server) transferLog(addr net.Addr, path string, err error) {
- if s.TransferLog != nil {
- s.TransferLog(addr, path, err)
- }
- }
- func (s *Server) transferAndLog(addr net.Addr, req *rrq) {
- err := s.transfer(addr, req)
- if err != nil {
- err = fmt.Errorf("%q: %s", addr, err)
- }
- s.transferLog(addr, req.Filename, err)
- }
- func (s *Server) transfer(addr net.Addr, req *rrq) error {
- d := s.Dial
- if d == nil {
- d = net.Dial
- }
- conn, err := d("udp", addr.String())
- if err != nil {
- return fmt.Errorf("creating socket: %s", err)
- }
- defer conn.Close()
- file, size, err := s.Handler(req.Filename, addr)
- if err != nil {
- conn.Write(tftpError("failed to get file"))
- return fmt.Errorf("getting file bytes: %s", err)
- }
- defer file.Close()
- var b bytes.Buffer
- if req.BlockSize != 0 || (req.WantSize && size != 0) {
- // Client requested options, need to OACK them before sending
- // data.
- b.WriteByte(0)
- b.WriteByte(6)
- if req.BlockSize != 0 {
- maxBlockSize := s.MaxBlockSize
- if maxBlockSize <= 0 {
- maxBlockSize = DefaultBlockSize
- }
- if req.BlockSize > maxBlockSize {
- s.infoLog("clamping blocksize to %q: %d -> %d", addr, req.BlockSize, maxBlockSize)
- req.BlockSize = maxBlockSize
- }
- b.WriteString("blksize")
- b.WriteByte(0)
- b.WriteString(strconv.FormatInt(req.BlockSize, 10))
- b.WriteByte(0)
- }
- if req.WantSize && size != 0 {
- b.WriteString("tsize")
- b.WriteByte(0)
- b.WriteString(strconv.FormatInt(size, 10))
- b.WriteByte(0)
- }
- if err := s.send(conn, b.Bytes(), 0); err != nil {
- return fmt.Errorf("sending OACK: %s", err)
- }
- b.Reset()
- }
- if req.BlockSize == 0 {
- // Client didn't negotiate, use classic blocksize from RFC.
- req.BlockSize = 512
- }
- seq := uint16(1)
- b.Grow(int(req.BlockSize + 4))
- b.WriteByte(0)
- b.WriteByte(3)
- for {
- b.Truncate(2)
- if err = binary.Write(&b, binary.BigEndian, seq); err != nil {
- conn.Write(tftpError("internal server error"))
- return fmt.Errorf("writing seqnum: %s", err)
- }
- n, err := io.CopyN(&b, file, req.BlockSize)
- if err != nil && err != io.EOF {
- conn.Write(tftpError("internal server error"))
- return fmt.Errorf("reading bytes for block %d: %s", seq, err)
- }
- if err = s.send(conn, b.Bytes(), seq); err != nil {
- conn.Write(tftpError("timeout"))
- return fmt.Errorf("sending data packet %d: %s", seq, err)
- }
- seq++
- if n < req.BlockSize {
- // Transfer complete
- return nil
- }
- }
- }
- func (s *Server) send(conn net.Conn, b []byte, seq uint16) error {
- timeout := s.WriteTimeout
- if timeout <= 0 {
- timeout = DefaultWriteTimeout
- }
- attempts := s.WriteAttempts
- if attempts <= 0 {
- attempts = DefaultWriteAttempts
- }
- Attempt:
- for attempt := 0; attempt < attempts; attempt++ {
- if _, err := conn.Write(b); err != nil {
- return err
- }
- conn.SetReadDeadline(time.Now().Add(timeout))
- var recv [256]byte
- for {
- n, err := conn.Read(recv[:])
- if err != nil {
- if t, ok := err.(net.Error); ok && t.Timeout() {
- continue Attempt
- }
- return err
- }
- if n < 4 { // packet too small
- continue
- }
- switch binary.BigEndian.Uint16(recv[:2]) {
- case 4:
- if binary.BigEndian.Uint16(recv[2:4]) == seq {
- return nil
- }
- case 5:
- msg, _, _ := tftpStr(recv[4:])
- return fmt.Errorf("client aborted transfer: %s", msg)
- }
- }
- }
- return errors.New("timeout waiting for ACK")
- }
// rrq is a decoded TFTP read request.
type rrq struct {
	// Filename is the path the client asked to read.
	Filename string
	// BlockSize is the "blksize" value the client requested, or 0 when
	// the option was absent. transfer later overwrites it with the
	// block size actually used.
	BlockSize int64
	// WantSize reports whether the client sent the "tsize" option.
	WantSize bool
}
- func parseRRQ(bs []byte) (*rrq, error) {
- // Smallest a useful TFTP packet can be is 6 bytes: 2b opcode, 1b
- // filename, 1b null, 1b mode, 1b null.
- if len(bs) < 6 || binary.BigEndian.Uint16(bs[:2]) != 1 {
- return nil, errors.New("not an RRQ packet")
- }
- fname, bs, err := tftpStr(bs[2:])
- if err != nil {
- return nil, fmt.Errorf("reading filename: %s", err)
- }
- mode, bs, err := tftpStr(bs)
- if err != nil {
- return nil, fmt.Errorf("reading mode: %s", err)
- }
- if mode != "octet" {
- // Only support octet mode, because in practice that's the
- // only remaining sensible use of TFTP (i.e. PXE booting)
- return nil, fmt.Errorf("unsupported transfer mode %q", mode)
- }
- req := &rrq{
- Filename: fname,
- }
- for len(bs) > 0 {
- opt, rest, err := tftpStr(bs)
- if err != nil {
- return nil, fmt.Errorf("reading option name: %s", err)
- }
- bs = rest
- val, rest, err := tftpStr(bs)
- if err != nil {
- return nil, fmt.Errorf("reading option %q value: %s", opt, err)
- }
- bs = rest
- if opt != "blksize" {
- if opt == "tsize" {
- req.WantSize = true
- }
- continue
- }
- size, err := strconv.ParseInt(val, 10, 64)
- if err != nil {
- return nil, fmt.Errorf("non-integer block size value %q", val)
- }
- if size < 8 || size > 65464 {
- return nil, fmt.Errorf("unsupported block size %q", size)
- }
- req.BlockSize = size
- }
- return req, nil
- }
- // tftpError constructs an ERROR packet.
- //
- // The error is coerced to the sensible subset of "netascii", namely
- // the printable ASCII characters plus newline.
- func tftpError(msg string) []byte {
- if len(msg) > maxErrorSize {
- msg = msg[:maxErrorSize]
- }
- var ret bytes.Buffer
- ret.Grow(len(msg) + 5)
- ret.Write([]byte{0, 5, 0, 0}) // generic "see message" error packet
- for _, b := range msg {
- switch {
- case b >= 0x20 && b <= 0x7E:
- ret.WriteRune(b)
- case b == '\r':
- // Assume this is the start of a CRLF sequence and just
- // swallow the CR. The LF will output CRLF, see
- // below. Also, please stop using CRLF line termination in
- // Go.
- case b == '\n':
- ret.WriteString("\r\n")
- default:
- ret.WriteByte('?')
- }
- }
- ret.WriteByte(0)
- return ret.Bytes()
- }
// tftpStr reads one NUL-terminated string from the front of bs. It returns
// that string together with the bytes following the terminator.
//
// Every byte before the terminator must be printable ASCII (0x20-0x7E),
// the "read-useful" part of netascii. The first byte outside that range
// produces an error that names the byte and its offset. An error is also
// returned when bs contains no NUL byte. On error, str is empty and
// remaining is nil.
func tftpStr(bs []byte) (str string, remaining []byte, err error) {
	end := bytes.IndexByte(bs, 0)
	body := bs
	if end >= 0 {
		body = bs[:end]
	}
	// Validate before reporting a missing terminator, so that a bad byte
	// is reported first whenever one precedes the end of the data.
	for off, c := range body {
		if c < 0x20 || c > 0x7E {
			return "", nil, fmt.Errorf("invalid netascii byte %q at offset %d", c, off)
		}
	}
	if end < 0 {
		return "", nil, errors.New("no null terminated string found")
	}
	return string(body), bs[end+1:], nil
}
|