kubernetes conn_test 源码

  • 2022-09-18
  • 浏览 (204)

kubernetes conn_test 代码

文件路径:/staging/src/k8s.io/apiserver/pkg/util/wsstream/conn_test.go

/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package wsstream

import (
	"encoding/base64"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"reflect"
	"sync"
	"testing"

	"golang.org/x/net/websocket"
)

// newServer starts an httptest.Server that serves handler. It returns the
// running server together with the host:port its listener is bound to; the
// tests in this file prefix that address with "ws://" to dial it.
func newServer(handler http.Handler) (*httptest.Server, string) {
	srv := httptest.NewServer(handler)
	return srv, srv.Listener.Addr().String()
}

// TestRawConn exercises a Conn whose client dials without requesting a
// subprotocol. In the frames exchanged below the first byte is the channel
// index and the rest is the unencoded payload. Five channels are configured so
// that read/write, ignored, read-only and write-only channels are all covered.
func TestRawConn(t *testing.T) {
	channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
	conn := NewConn(NewDefaultChannelProtocols(channels))

	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		conn.Open(w, req)
	}))
	defer s.Close()

	client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	<-conn.ready
	wg := sync.WaitGroup{}

	// verify we can read a client write
	wg.Add(1)
	go func() {
		defer wg.Done()
		// t.Fatal/FailNow may only be called from the test goroutine, so the
		// helper goroutines report with t.Error(f) and return instead.
		// ReadAll finishes once the client is closed at the end of the test.
		data, err := ioutil.ReadAll(conn.channels[0])
		if err != nil {
			t.Error(err)
			return
		}
		if !reflect.DeepEqual(data, []byte("client")) {
			t.Errorf("unexpected server read: %v", data)
		}
	}()

	// frame = channel index 0 followed by the raw payload
	if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
		t.Fatalf("%d: %v", n, err)
	}

	// verify we can read a server write
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Either an error or a short write is a failure.
		if n, err := conn.channels[1].Write([]byte("server")); err != nil || n != 6 {
			t.Errorf("%d: %v", n, err)
		}
	}()

	// The expected frame is channel index 1 plus "server": 7 bytes in total.
	data := make([]byte, 1024)
	if n, err := io.ReadAtLeast(client, data, 7); n != 7 || err != nil {
		t.Fatalf("%d: %v", n, err)
	}
	if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
		t.Errorf("unexpected client read: %v", data[:7])
	}

	// verify that an ignore channel is empty in both directions.
	if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
		t.Errorf("writes should be ignored")
	}
	data = make([]byte, 1024)
	if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
		t.Errorf("reads should be ignored")
	}

	// verify that a write to a Read channel doesn't block
	if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
		t.Errorf("writes should be ignored")
	}

	// verify that a read from a Write channel doesn't block
	data = make([]byte, 1024)
	if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
		t.Errorf("reads should be ignored")
	}

	// verify that a client write to a Write channel doesn't block (is dropped)
	if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
		t.Fatalf("%d: %v", n, err)
	}

	// Closing the client lets the channel-0 reader see EOF, so the helper
	// goroutines can finish before we wait for them.
	client.Close()
	wg.Wait()
}

// TestBase64Conn exercises the "base64.channel.k8s.io" subprotocol. Each frame
// starts with the ASCII digit of the channel index ('0', '1', ...), followed by
// the payload encoded with base64.StdEncoding.
func TestBase64Conn(t *testing.T) {
	conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		conn.Open(w, req)
	}))
	defer s.Close()

	config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
	if err != nil {
		t.Fatal(err)
	}
	config.Protocol = []string{"base64.channel.k8s.io"}
	client, err := websocket.DialConfig(config)
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	<-conn.ready
	wg := sync.WaitGroup{}

	// verify the server decodes a client write on channel 0
	wg.Add(1)
	go func() {
		defer wg.Done()
		// t.Fatal/FailNow may only be called from the test goroutine, so the
		// helper goroutines report with t.Error(f) and return instead.
		data, err := ioutil.ReadAll(conn.channels[0])
		if err != nil {
			t.Error(err)
			return
		}
		if !reflect.DeepEqual(data, []byte("client")) {
			t.Errorf("unexpected server read: %s", string(data))
		}
	}()

	clientData := base64.StdEncoding.EncodeToString([]byte("client"))
	if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
		t.Fatalf("%d: %v", n, err)
	}

	// verify the client receives an encoded server write on channel 1
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Either an error or a short write is a failure.
		if n, err := conn.channels[1].Write([]byte("server")); err != nil || n != 6 {
			t.Errorf("%d: %v", n, err)
		}
	}()

	// Expected frame: channel digit '1' followed by base64("server"). Its length
	// is derived rather than hard-coded.
	expect := append([]byte{'1'}, base64.StdEncoding.EncodeToString([]byte("server"))...)
	data := make([]byte, 1024)
	if n, err := io.ReadAtLeast(client, data, len(expect)); n != len(expect) || err != nil {
		t.Fatalf("%d: %v", n, err)
	}
	if !reflect.DeepEqual(data[:len(expect)], expect) {
		t.Errorf("unexpected client read: %v", data[:len(expect)])
	}

	client.Close()
	wg.Wait()
}

// versionTest describes one protocol-negotiation case run by TestVersionedConn.
//
//   - supported: the protocols the server is configured with, each mapped to
//     whether it uses binary framing (true) or base64 framing (false); a nil
//     map means the server supports no protocol at all.
//   - requested: the protocols the client offers in its websocket handshake.
//   - error: whether the client dial is expected to fail.
//   - expected: the protocol the server is expected to select when the dial
//     succeeds.
type versionTest struct {
	supported map[string]bool // protocol -> binary
	requested []string
	error     bool
	expected  string
}

// versionTests returns the negotiation table consumed by TestVersionedConn.
// Each case gives the protocols the server supports (mapped to binary or base64
// framing) and the protocols the client requests. It then states either the
// protocol the server should pick or that the handshake should fail.
func versionTests() []versionTest {
	const (
		binary = true
		base64 = false
	)

	// Each call builds a fresh table, so no two cases share a map.
	unversioned := func() map[string]bool {
		return map[string]bool{"": binary, "raw": binary, "base64": base64}
	}
	versioned := func() map[string]bool {
		return map[string]bool{
			"":          binary,
			"v1.raw":    binary,
			"v1.base64": base64,
			"v2.raw":    binary,
			"v2.base64": base64,
		}
	}

	return []versionTest{
		// server supports nothing
		{requested: []string{"raw"}, error: true},
		// client requests nothing: the unnamed protocol is chosen
		{supported: unversioned(), expected: ""},
		// only unknown versioned protocols requested
		{supported: unversioned(), requested: []string{"v1.raw"}, error: true},
		{supported: unversioned(), requested: []string{"v1.raw", "v1.base64"}, error: true},
		// an unknown entry is skipped in favour of a supported one
		{supported: unversioned(), requested: []string{"v1.raw", "raw"}, expected: "raw"},
		// versioned protocols on a versioned server
		{supported: versioned(), requested: []string{"v1.raw"}, expected: "v1.raw"},
		{supported: versioned(), requested: []string{"v2.base64"}, expected: "v2.base64"},
	}
}

// TestVersionedConn checks subprotocol negotiation for every case in
// versionTests. It configures the server with test.supported (one read/write
// channel per protocol) and has the client offer test.requested. It then
// asserts either that the dial fails or that the server selected
// test.expected.
func TestVersionedConn(t *testing.T) {
	for i, test := range versionTests() {
		// Run each case in its own function so the deferred Close calls fire at
		// the end of that case rather than at the end of the whole test.
		func() {
			supportedProtocols := map[string]ChannelProtocolConfig{}
			for p, binary := range test.supported {
				supportedProtocols[p] = ChannelProtocolConfig{
					Binary:   binary,
					Channels: []ChannelType{ReadWriteChannel},
				}
			}
			conn := NewConn(supportedProtocols)
			// Waiting for conn.ready alone is not enough to avoid a race here,
			// so the handler reports the selected protocol over a channel. The
			// channel is buffered so the handler never blocks forever on the
			// send when this case returns before receiving from it.
			selectedProtocol := make(chan string, 1)
			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				p, _, _ := conn.Open(w, req)
				selectedProtocol <- p
			}))
			defer s.Close()

			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
			if err != nil {
				t.Fatal(err)
			}
			config.Protocol = test.requested
			client, err := websocket.DialConfig(config)
			if test.error {
				if err == nil {
					client.Close()
					t.Fatalf("test %d: expected an error", i)
				}
				return
			}
			if err != nil {
				t.Fatalf("test %d: didn't expect error: %v", i, err)
			}
			defer client.Close()

			<-conn.ready
			if got, expected := <-selectedProtocol, test.expected; got != expected {
				t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
			}
		}()
	}
}

相关信息

kubernetes 源码目录

相关文章

kubernetes conn 源码

kubernetes doc 源码

kubernetes stream 源码

kubernetes stream_test 源码

0  赞