kubernetes stream_test 源码

  • 2022-09-18
  • 浏览 (225)

kubernetes stream_test 代码

文件路径:/staging/src/k8s.io/apiserver/pkg/util/wsstream/stream_test.go

/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package wsstream

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"

	"golang.org/x/net/websocket"
)

func TestStream(t *testing.T) {
	input := "some random text"
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
	r.SetIdleTimeout(time.Second)
	data, err := readWebSocket(r, t, nil)
	if !reflect.DeepEqual(data, []byte(input)) {
		t.Errorf("unexpected server read: %v", data)
	}
	if err != nil {
		t.Fatal(err)
	}
}

func TestStreamPing(t *testing.T) {
	input := "some random text"
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
	r.SetIdleTimeout(time.Second)
	err := expectWebSocketFrames(r, t, nil, [][]byte{
		{},
		[]byte(input),
	})
	if err != nil {
		t.Fatal(err)
	}
}

func TestStreamBase64(t *testing.T) {
	input := "some random text"
	encoded := base64.StdEncoding.EncodeToString([]byte(input))
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
	data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
	if !reflect.DeepEqual(data, []byte(encoded)) {
		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
	}
	if err != nil {
		t.Fatal(err)
	}
}

func TestStreamVersionedBase64(t *testing.T) {
	input := "some random text"
	encoded := base64.StdEncoding.EncodeToString([]byte(input))
	r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
		"":                        {Binary: true},
		"binary.k8s.io":           {Binary: true},
		"base64.binary.k8s.io":    {Binary: false},
		"v1.binary.k8s.io":        {Binary: true},
		"v1.base64.binary.k8s.io": {Binary: false},
		"v2.binary.k8s.io":        {Binary: true},
		"v2.base64.binary.k8s.io": {Binary: false},
	})
	data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
	if !reflect.DeepEqual(data, []byte(encoded)) {
		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
	}
	if err != nil {
		t.Fatal(err)
	}
}

func TestStreamVersionedCopy(t *testing.T) {
	for i, test := range versionTests() {
		func() {
			supportedProtocols := map[string]ReaderProtocolConfig{}
			for p, binary := range test.supported {
				supportedProtocols[p] = ReaderProtocolConfig{
					Binary: binary,
				}
			}
			input := "some random text"
			r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				err := r.Copy(w, req)
				if err != nil {
					w.WriteHeader(503)
				}
			}))
			defer s.Close()

			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
			if err != nil {
				t.Error(err)
				return
			}
			config.Protocol = test.requested
			client, err := websocket.DialConfig(config)
			if err != nil {
				if !test.error {
					t.Errorf("test %d: didn't expect error: %v", i, err)
				}
				return
			}
			defer client.Close()
			if test.error && err == nil {
				t.Errorf("test %d: expected an error", i)
				return
			}

			<-r.err
			if got, expected := r.selectedProtocol, test.expected; got != expected {
				t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
			}
		}()
	}
}

func TestStreamError(t *testing.T) {
	input := "some random text"
	errs := &errorReader{
		reads: [][]byte{
			[]byte("some random"),
			[]byte(" text"),
		},
		err: fmt.Errorf("bad read"),
	}
	r := NewReader(errs, false, NewDefaultReaderProtocols())

	data, err := readWebSocket(r, t, nil)
	if !reflect.DeepEqual(data, []byte(input)) {
		t.Errorf("unexpected server read: %v", data)
	}
	if err == nil || err.Error() != "bad read" {
		t.Fatal(err)
	}
}

func TestStreamSurvivesPanic(t *testing.T) {
	input := "some random text"
	errs := &errorReader{
		reads: [][]byte{
			[]byte("some random"),
			[]byte(" text"),
		},
		panicMessage: "bad read",
	}
	r := NewReader(errs, false, NewDefaultReaderProtocols())

	// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
	r.handleCrash = func(additionalHandlers ...func(interface{})) { recover() }

	data, err := readWebSocket(r, t, nil)
	if !reflect.DeepEqual(data, []byte(input)) {
		t.Errorf("unexpected server read: %v", data)
	}
	if err != nil {
		t.Fatal(err)
	}
}

func TestStreamClosedDuringRead(t *testing.T) {
	for i := 0; i < 25; i++ {
		ch := make(chan struct{})
		input := "some random text"
		errs := &errorReader{
			reads: [][]byte{
				[]byte("some random"),
				[]byte(" text"),
			},
			err:   fmt.Errorf("stuff"),
			pause: ch,
		}
		r := NewReader(errs, false, NewDefaultReaderProtocols())

		data, err := readWebSocket(r, t, func(c *websocket.Conn) {
			c.Close()
			close(ch)
		})
		// verify that the data returned by the server on an early close always has a specific error
		if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
			t.Fatal(err)
		}
		// verify that the data returned is a strict subset of the input
		if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
			t.Fatalf("unexpected server read: %q", string(data))
		}
	}
}

type errorReader struct {
	reads        [][]byte
	err          error
	panicMessage string
	pause        chan struct{}
}

func (r *errorReader) Read(p []byte) (int, error) {
	if len(r.reads) == 0 {
		if r.pause != nil {
			<-r.pause
		}
		if len(r.panicMessage) != 0 {
			panic(r.panicMessage)
		}
		return 0, r.err
	}
	next := r.reads[0]
	r.reads = r.reads[1:]
	copy(p, next)
	return len(next), nil
}

func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
	errCh := make(chan error, 1)
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		errCh <- r.Copy(w, req)
	}))
	defer s.Close()

	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
	config.Protocol = protocols
	client, err := websocket.DialConfig(config)
	if err != nil {
		return nil, err
	}
	defer client.Close()

	if fn != nil {
		fn(client)
	}

	data, err := ioutil.ReadAll(client)
	if err != nil {
		return data, err
	}
	return data, <-errCh
}

func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
	errCh := make(chan error, 1)
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		errCh <- r.Copy(w, req)
	}))
	defer s.Close()

	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
	config.Protocol = protocols
	ws, err := websocket.DialConfig(config)
	if err != nil {
		return err
	}
	defer ws.Close()

	if fn != nil {
		fn(ws)
	}

	for i := range frames {
		var data []byte
		if err := websocket.Message.Receive(ws, &data); err != nil {
			return err
		}
		if !reflect.DeepEqual(frames[i], data) {
			return fmt.Errorf("frame %d did not match expected: %v", data, err)
		}
	}
	var data []byte
	if err := websocket.Message.Receive(ws, &data); err != io.EOF {
		return fmt.Errorf("expected no more frames: %v (%v)", err, data)
	}
	return <-errCh
}

相关信息

kubernetes 源码目录

相关文章

kubernetes conn 源码

kubernetes conn_test 源码

kubernetes doc 源码

kubernetes stream 源码

0  赞